reab5555 commited on
Commit
5e351cb
·
verified ·
1 Parent(s): 1f43f91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -204
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import torch
2
  from PIL import Image
3
  import requests
4
- from openai import OpenAI
5
  from transformers import (Owlv2Processor, Owlv2ForObjectDetection,
6
- AutoProcessor, AutoModelForMaskGeneration)
 
7
  import matplotlib.pyplot as plt
8
  import matplotlib.patches as patches
9
  import base64
@@ -17,199 +18,50 @@ from dotenv import load_dotenv
17
  # Load environment variables
18
  load_dotenv()
19
  OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
 
20
 
 
 
 
 
21
 
22
- def encode_image_to_base64(image):
23
- # If image is a tuple (e.g., Gradio input), take the first element
24
- if isinstance(image, tuple):
25
- image = image[0] # Extract the image from the tuple
26
 
27
- # If image is a numpy array, convert it to a PIL Image
28
- if isinstance(image, np.ndarray):
29
- image = Image.fromarray(image)
30
-
31
- # Ensure image is in PIL Image format
32
- if not isinstance(image, Image.Image):
33
- raise ValueError("Input must be a PIL Image, numpy array, or tuple containing an image")
34
 
35
- buffered = io.BytesIO()
36
- image.save(buffered, format="PNG")
37
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
38
 
 
39
 
40
- def analyze_image(image):
41
- client = OpenAI(api_key=OPENAI_API_KEY)
42
- base64_image = encode_image_to_base64(image)
43
 
44
- messages = [
45
- {
46
- "role": "user",
47
- "content": [
48
- {
49
- "type": "text",
50
- "text": """Your task is to determine if the image is surprising or not surprising.
51
- if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
52
- Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
53
- Provide the response as a JSON with the following structure:
54
- {
55
- "label": "[surprising OR not surprising]",
56
- "element": "[element]",
57
- "rating": [1-5]
58
- }"""
59
- },
60
- {
61
- "type": "image_url",
62
- "image_url": {
63
- "url": f"data:image/jpeg;base64,{base64_image}"
64
- }
65
- }
66
- ]
67
  }
68
  ]
69
 
70
- response = client.chat.completions.create(
71
- model="gpt-4o-mini",
72
  messages=messages,
73
  max_tokens=100,
74
- temperature=0.1,
75
- response_format={
76
- "type": "json_object"
77
- }
78
  )
79
 
80
  return response.choices[0].message.content
81
 
82
-
83
- def show_mask(mask, ax, random_color=False):
84
- if random_color:
85
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
86
- else:
87
- color = np.array([1.0, 0.0, 0.0, 0.5])
88
-
89
- if len(mask.shape) == 4:
90
- mask = mask[0, 0]
91
-
92
- mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
93
- mask_image[mask > 0] = color
94
-
95
- ax.imshow(mask_image)
96
-
97
-
98
- def process_image_detection(image, target_label, surprise_rating):
99
- device = "cuda" if torch.cuda.is_available() else "cpu"
100
-
101
- # Get original image DPI and size
102
- original_dpi = image.info.get('dpi', (72, 72))
103
- original_size = image.size
104
-
105
- # Calculate relative font size based on image dimensions
106
- base_fontsize = min(original_size) / 40 # Adjust this divisor to change overall font size
107
-
108
- owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
109
- owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
110
-
111
- sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
112
- sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
113
-
114
- image_np = np.array(image)
115
-
116
- inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
117
- with torch.no_grad():
118
- outputs = owlv2_model(**inputs)
119
-
120
- target_sizes = torch.tensor([image.size[::-1]]).to(device)
121
- results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
122
-
123
- dpi = 300 # Increased DPI for better text rendering
124
- figsize = (original_size[0] / dpi, original_size[1] / dpi)
125
- fig = plt.figure(figsize=figsize, dpi=dpi)
126
-
127
- ax = plt.Axes(fig, [0., 0., 1., 1.])
128
- fig.add_axes(ax)
129
-
130
- plt.imshow(image)
131
-
132
- scores = results["scores"]
133
- if len(scores) > 0:
134
- max_score_idx = scores.argmax().item()
135
- max_score = scores[max_score_idx].item()
136
-
137
- if max_score > 0.2:
138
- box = results["boxes"][max_score_idx].cpu().numpy()
139
-
140
- sam_inputs = sam_processor(
141
- image,
142
- input_boxes=[[[box[0], box[1], box[2], box[3]]]],
143
- return_tensors="pt"
144
- ).to(device)
145
-
146
- with torch.no_grad():
147
- sam_outputs = sam_model(**sam_inputs)
148
-
149
- masks = sam_processor.image_processor.post_process_masks(
150
- sam_outputs.pred_masks.cpu(),
151
- sam_inputs["original_sizes"].cpu(),
152
- sam_inputs["reshaped_input_sizes"].cpu()
153
- )
154
-
155
- mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
156
- show_mask(mask, ax=ax)
157
-
158
- # Draw rectangle with increased line width
159
- rect = patches.Rectangle(
160
- (box[0], box[1]),
161
- box[2] - box[0],
162
- box[3] - box[1],
163
- linewidth=max(2, min(original_size) / 500), # Scale line width with image size
164
- edgecolor='red',
165
- facecolor='none'
166
- )
167
- ax.add_patch(rect)
168
-
169
- # Add confidence score with improved visibility
170
- plt.text(
171
- box[0], box[1] - base_fontsize,
172
- f'{max_score:.2f}',
173
- color='red',
174
- fontsize=base_fontsize,
175
- fontweight='bold',
176
- bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
177
- )
178
-
179
- # Add label and rating with improved visibility
180
- plt.text(
181
- box[2] + base_fontsize / 2, box[1],
182
- f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
183
- color='red',
184
- fontsize=base_fontsize,
185
- fontweight='bold',
186
- bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
187
- verticalalignment='bottom'
188
- )
189
-
190
- plt.axis('off')
191
-
192
- # Save with high DPI
193
- buf = io.BytesIO()
194
- plt.savefig(buf,
195
- format='png',
196
- dpi=dpi,
197
- bbox_inches='tight',
198
- pad_inches=0,
199
- metadata={'dpi': original_dpi})
200
- buf.seek(0)
201
- plt.close()
202
-
203
- # Process final image
204
- output_image = Image.open(buf)
205
- output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
206
-
207
- final_buf = io.BytesIO()
208
- output_image.save(final_buf, format='PNG', dpi=original_dpi)
209
- final_buf.seek(0)
210
-
211
- return final_buf
212
-
213
 
214
  def process_and_analyze(image):
215
  if image is None:
@@ -227,8 +79,11 @@ def process_and_analyze(image):
227
  if not isinstance(image, Image.Image):
228
  raise ValueError("Invalid image format")
229
 
230
- # Analyze image
231
- gpt_response = analyze_image(image)
 
 
 
232
  response_data = json.loads(gpt_response)
233
 
234
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
@@ -242,29 +97,7 @@ def process_and_analyze(image):
242
  except Exception as e:
243
  return None, f"Error processing image: {str(e)}"
244
 
245
-
246
- # Create Gradio interface
247
- def create_interface():
248
- with gr.Blocks() as demo:
249
- gr.Markdown("# Image Surprise Analysis")
250
-
251
- with gr.Row():
252
- with gr.Column():
253
- input_image = gr.Image(label="Upload Image")
254
- analyze_btn = gr.Button("Analyze Image")
255
-
256
- with gr.Column():
257
- output_image = gr.Image(label="Processed Image")
258
- output_text = gr.Textbox(label="Analysis Results")
259
-
260
- analyze_btn.click(
261
- fn=process_and_analyze,
262
- inputs=[input_image],
263
- outputs=[output_image, output_text]
264
- )
265
-
266
- return demo
267
-
268
 
269
  if __name__ == "__main__":
270
  demo = create_interface()
 
1
  import torch
2
  from PIL import Image
3
  import requests
4
+ import openai
5
  from transformers import (Owlv2Processor, Owlv2ForObjectDetection,
6
+ AutoProcessor, AutoModelForMaskGeneration,
7
+ BlipProcessor, BlipForConditionalGeneration)
8
  import matplotlib.pyplot as plt
9
  import matplotlib.patches as patches
10
  import base64
 
18
  # Load environment variables
19
  load_dotenv()
20
  OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
21
+ openai.api_key = OPENAI_API_KEY
22
 
23
+ def generate_image_caption(image):
24
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
+ processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
26
+ model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base').to(device)
27
 
28
+ inputs = processor(image, return_tensors='pt').to(device)
29
+ out = model.generate(**inputs)
30
+ caption = processor.decode(out[0], skip_special_tokens=True)
31
+ return caption
32
 
33
+ def analyze_caption(caption):
34
+ messages = [
35
+ {
36
+ "role": "user",
37
+ "content": f"""Your task is to determine if the following image description is surprising or not surprising.
 
 
38
 
39
+ Description: "{caption}"
 
 
40
 
41
+ If the description is surprising, determine which element, figure, or object is making it surprising and write it only in one sentence with no more than 6 words; otherwise, write 'NA'.
42
 
43
+ Also, rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
 
 
44
 
45
+ Provide the response as a JSON with the following structure:
46
+ {{
47
+ "label": "[surprising OR not surprising]",
48
+ "element": "[element]",
49
+ "rating": [1-5]
50
+ }}
51
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  }
53
  ]
54
 
55
+ response = openai.ChatCompletion.create(
56
+ model="gpt-4",
57
  messages=messages,
58
  max_tokens=100,
59
+ temperature=0.1
 
 
 
60
  )
61
 
62
  return response.choices[0].message.content
63
 
64
+ # The rest of your functions (process_image_detection, show_mask, etc.) remain the same
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def process_and_analyze(image):
67
  if image is None:
 
79
  if not isinstance(image, Image.Image):
80
  raise ValueError("Invalid image format")
81
 
82
+ # Generate caption
83
+ caption = generate_image_caption(image)
84
+
85
+ # Analyze caption
86
+ gpt_response = analyze_caption(caption)
87
  response_data = json.loads(gpt_response)
88
 
89
  if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
 
97
  except Exception as e:
98
  return None, f"Error processing image: {str(e)}"
99
 
100
+ # Create Gradio interface remains the same
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  if __name__ == "__main__":
103
  demo = create_interface()