reab5555 committed
Commit a27e6f2 · verified · 1 Parent(s): 6ab9894

Update app.py

Files changed (1):
  app.py (+181 -36)
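In short, the commit drops the BLIP captioning step and sends the image itself to a vision-capable chat model, migrating from the legacy module-level openai API to the v1 OpenAI client. Below is a minimal sketch of the call pattern the diff adopts; it is not part of the commit, and it assumes openai>=1.0 is installed and OPENAI_API_KEY is exported (the solid-color test image is a stand-in):

# Sketch of the v1-client vision call this commit switches to (assumptions noted above).
import base64, io, os
from openai import OpenAI
from PIL import Image

# Encode a PIL image as a base64 PNG payload, as the new encode_image_to_base64 helper does.
img = Image.new("RGB", (64, 64), "red")  # stand-in image for the sketch
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Old (removed): openai.api_key = ...; openai.ChatCompletion.create(model="gpt-4", ...)
# New (added): client.chat.completions.create with image input and JSON-mode output.
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": 'Reply in JSON: {"color": "<dominant color>"}'},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ],
    }],
    max_tokens=100,
    temperature=0.1,
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)  # e.g. {"color": "red"}

Note that JSON mode (response_format={"type": "json_object"}) requires the word "JSON" to appear in the messages, which the committed prompt satisfies.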
app.py CHANGED
@@ -1,10 +1,9 @@
 import torch
 from PIL import Image
 import requests
-import openai
+from openai import OpenAI
 from transformers import (Owlv2Processor, Owlv2ForObjectDetection,
-                          AutoProcessor, AutoModelForMaskGeneration,
-                          BlipProcessor, BlipForConditionalGeneration)
+                          AutoProcessor, AutoModelForMaskGeneration)
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 import base64
@@ -18,50 +17,198 @@ from dotenv import load_dotenv
 # Load environment variables
 load_dotenv()
 OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
-openai.api_key = OPENAI_API_KEY
 
-def generate_image_caption(image):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
-    model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base').to(device)
 
-    inputs = processor(image, return_tensors='pt').to(device)
-    out = model.generate(**inputs)
-    caption = processor.decode(out[0], skip_special_tokens=True)
-    return caption
+def encode_image_to_base64(image):
+    # If image is a tuple (as sometimes provided by Gradio), take the first element
+    if isinstance(image, tuple):
+        image = image[0]
 
-def analyze_caption(caption):
-    messages = [
-        {
-            "role": "user",
-            "content": f"""Your task is to determine if the following image description is surprising or not surprising.
+    # If image is a numpy array, convert to PIL Image
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
 
-            Description: "{caption}"
+    # Ensure image is in PIL Image format
+    if not isinstance(image, Image.Image):
+        raise ValueError("Input must be a PIL Image, numpy array, or tuple containing an image")
 
-            If the description is surprising, determine which element, figure, or object is making it surprising and write it only in one sentence with no more than 6 words; otherwise, write 'NA'.
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
-            Also, rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
+def analyze_image(image):
+    client = OpenAI(api_key=OPENAI_API_KEY)
+    base64_image = encode_image_to_base64(image)
 
-            Provide the response as a JSON with the following structure:
-            {{
-                "label": "[surprising OR not surprising]",
-                "element": "[element]",
-                "rating": [1-5]
-            }}
-            """
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": """Your task is to determine if the image is surprising or not surprising.
+                    if the image is surprising, determine which element, figure or object in the image is making the image surprising and write it only in one sentence with no more then 6 words, otherwise, write 'NA'.
+                    Also rate how surprising the image is on a scale of 1-5, where 1 is not surprising at all and 5 is highly surprising.
+                    Provide the response as a JSON with the following structure:
+                    {
+                        "label": "[surprising OR not surprising]",
+                        "element": "[element]",
+                        "rating": [1-5]
+                    }"""
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{base64_image}"
+                    }
+                }
+            ]
         }
     ]
 
-    response = openai.ChatCompletion.create(
-        model="gpt-4",
+    response = client.chat.completions.create(
+        model="gpt-4o-mini",
         messages=messages,
         max_tokens=100,
-        temperature=0.1
+        temperature=0.1,
+        response_format={
+            "type": "json_object"
+        }
     )
 
     return response.choices[0].message.content
 
-# The rest of your functions (process_image_detection, show_mask, etc.) remain the same
+
+def show_mask(mask, ax, random_color=False):
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([1.0, 0.0, 0.0, 0.5])
+
+    if len(mask.shape) == 4:
+        mask = mask[0, 0]
+
+    mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
+    mask_image[mask > 0] = color
+
+    ax.imshow(mask_image)
+
+
+def process_image_detection(image, target_label, surprise_rating):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Get original image DPI and size
+    original_dpi = image.info.get('dpi', (72, 72))
+    original_size = image.size
+
+    # Calculate relative font size based on image dimensions
+    base_fontsize = min(original_size) / 40  # Adjust this divisor to change overall font size
+
+    owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
+    owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+
+    sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
+    sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
+
+    image_np = np.array(image)
+
+    inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = owlv2_model(**inputs)
+
+    target_sizes = torch.tensor([image.size[::-1]]).to(device)
+    results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+    dpi = 300  # Increased DPI for better text rendering
+    figsize = (original_size[0] / dpi, original_size[1] / dpi)
+    fig = plt.figure(figsize=figsize, dpi=dpi)
+
+    ax = plt.Axes(fig, [0., 0., 1., 1.])
+    fig.add_axes(ax)
+
+    plt.imshow(image)
+
+    scores = results["scores"]
+    if len(scores) > 0:
+        max_score_idx = scores.argmax().item()
+        max_score = scores[max_score_idx].item()
+
+        if max_score > 0.2:
+            box = results["boxes"][max_score_idx].cpu().numpy()
+
+            sam_inputs = sam_processor(
+                image,
+                input_boxes=[[[box[0], box[1], box[2], box[3]]]],
+                return_tensors="pt"
+            ).to(device)
+
+            with torch.no_grad():
+                sam_outputs = sam_model(**sam_inputs)
+
+            masks = sam_processor.image_processor.post_process_masks(
+                sam_outputs.pred_masks.cpu(),
+                sam_inputs["original_sizes"].cpu(),
+                sam_inputs["reshaped_input_sizes"].cpu()
+            )
+
+            mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
+            show_mask(mask, ax=ax)
+
+            # Draw rectangle with increased line width
+            rect = patches.Rectangle(
+                (box[0], box[1]),
+                box[2] - box[0],
+                box[3] - box[1],
+                linewidth=max(2, min(original_size) / 500),  # Scale line width with image size
+                edgecolor='red',
+                facecolor='none'
+            )
+            ax.add_patch(rect)
+
+            # Add confidence score with improved visibility
+            plt.text(
+                box[0], box[1] - base_fontsize,
+                f'{max_score:.2f}',
+                color='red',
+                fontsize=base_fontsize,
+                fontweight='bold',
+                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
+            )
+
+            # Add label and rating with improved visibility
+            plt.text(
+                box[2] + base_fontsize / 2, box[1],
+                f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
+                color='red',
+                fontsize=base_fontsize,
+                fontweight='bold',
+                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
+                verticalalignment='bottom'
+            )
+
+    plt.axis('off')
+
+    # Save with high DPI
+    buf = io.BytesIO()
+    plt.savefig(buf,
+                format='png',
+                dpi=dpi,
+                bbox_inches='tight',
+                pad_inches=0,
+                metadata={'dpi': original_dpi})
+    buf.seek(0)
+    plt.close()
+
+    # Process final image
+    output_image = Image.open(buf)
+    output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
+
+    final_buf = io.BytesIO()
+    output_image.save(final_buf, format='PNG', dpi=original_dpi)
+    final_buf.seek(0)
+
+    return final_buf
+
 
 def process_and_analyze(image):
     if image is None:
@@ -79,11 +226,8 @@ def process_and_analyze(image):
         if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
 
-        # Generate caption
-        caption = generate_image_caption(image)
-
-        # Analyze caption
-        gpt_response = analyze_caption(caption)
+        # Analyze image
+        gpt_response = analyze_image(image)
         response_data = json.loads(gpt_response)
 
         if response_data["label"].lower() == "surprising" and response_data["element"].lower() != "na":
@@ -97,6 +241,7 @@ def process_and_analyze(image):
     except Exception as e:
         return None, f"Error processing image: {str(e)}"
 
+
 # Create Gradio interface
 def create_interface():
     with gr.Blocks() as demo:
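For reference, a minimal sketch of driving the updated pipeline outside the Gradio UI; it assumes the committed file is importable as app and that OPENAI_API_KEY is set, and test.jpg / annotated.png are placeholder paths:

# Sketch: exercising the committed pipeline end to end, outside Gradio (assumptions above).
import json
from PIL import Image
from app import analyze_image, process_image_detection

image = Image.open("test.jpg").convert("RGB")  # placeholder input path

# Step 1: the GPT call labels the image and names the surprising element as JSON.
result = json.loads(analyze_image(image))

# Step 2: if surprising, OWLv2 localizes that element and SAM segments it.
if result["label"].lower() == "surprising" and result["element"].lower() != "na":
    buf = process_image_detection(image, result["element"], result["rating"])
    with open("annotated.png", "wb") as f:  # placeholder output path
        f.write(buf.getvalue())

The guard around the detection step mirrors the branch inside process_and_analyze: detection and segmentation only run when the model labels the image surprising with a concrete element.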