ayajoharji committed on
Commit
64c5d24
·
verified ·
1 Parent(s): a9550ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -93
app.py CHANGED
@@ -1,174 +1,340 @@
1
  # app.py
2
 
 
3
  import numpy as np
4
  import gradio as gr
5
  from sklearn.cluster import KMeans
6
- from transformers import (
7
- BlipProcessor,
8
- BlipForConditionalGeneration,
9
- MBartForConditionalGeneration,
10
- MBart50TokenizerFast,
11
- )
12
  from PIL import Image, ImageDraw
13
  import requests
14
  from io import BytesIO
15
 
16
- # Load models globally at startup
17
- print("Loading models...")
18
- blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
19
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
20
- mbart_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
21
- mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
22
- print("Models loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Download example images
25
  def download_example_images():
 
 
 
 
 
 
 
26
  image_urls = [
27
  # URL format: ("Image Description", "Image URL")
28
- ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512"),
29
- ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512"),
30
- ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b?w=512"),
31
- ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512"),
32
- ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ]
34
 
35
- example_images = []
36
  for idx, (description, url) in enumerate(image_urls, start=1):
37
  try:
38
  response = requests.get(url)
39
  if response.status_code == 200:
 
40
  img = Image.open(BytesIO(response.content))
41
  img.save(f'example{idx}.jpg')
42
- example_images.append([f'example{idx}.jpg'])
43
  else:
44
  print(f"Failed to download image from {url}")
45
  except Exception as e:
46
  print(f"Exception occurred while downloading image: {e}")
47
- return example_images
48
 
49
  # Download example images and prepare examples list
50
  examples = download_example_images()
51
 
52
- # Load and Process the Entire Image
53
  def load_image(image):
 
 
 
 
 
 
 
 
 
54
  # Convert PIL image to numpy array (RGB)
55
  image_np = np.array(image.convert('RGB'))
56
 
57
- # Resize the image for better processing
58
- resized_image = image.resize((224, 224), resample=Image.LANCZOS)
59
  resized_image_np = np.array(resized_image)
60
 
61
  return resized_image_np
62
 
63
- # Extract Dominant Colors from the Image
64
  def extract_colors(image, k=8):
65
- # Flatten the image
 
 
 
 
 
 
 
 
 
 
66
  pixels = image.reshape(-1, 3)
 
67
  # Normalize pixel values to [0, 1]
68
  pixels = pixels / 255.0
69
- # Ensure data type is float64
70
  pixels = pixels.astype(np.float64)
71
- # Apply K-means clustering to find dominant colors
72
- kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
 
 
 
 
 
 
73
  kmeans.fit(pixels)
 
74
  # Convert normalized colors back to 0-255 scale
75
  colors = (kmeans.cluster_centers_ * 255).astype(int)
76
  return colors
77
 
78
- # Create an Image for the Color Palette
79
  def create_palette_image(colors):
 
 
 
 
 
 
 
 
 
80
  num_colors = len(colors)
81
- palette_height = 50
82
- palette_width = 50 * num_colors
83
- palette_image = Image.new("RGB", (palette_width, palette_height))
 
 
 
84
 
85
  draw = ImageDraw.Draw(palette_image)
86
  for i, color in enumerate(colors):
87
- # Ensure color values are within the valid range and integers
88
  color = tuple(np.clip(color, 0, 255).astype(int))
89
- draw.rectangle([i * 50, 0, (i + 1) * 50, palette_height], fill=color)
 
 
 
 
90
 
91
  return palette_image
92
 
93
- # Display Color Palette as Hex Codes
94
  def display_palette(colors):
 
 
 
 
 
 
 
 
 
95
  hex_colors = []
96
  for color in colors:
97
  # Ensure color values are within valid range and integers
98
  color = np.clip(color, 0, 255).astype(int)
99
- hex_color = "#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2])
 
 
 
 
 
100
  hex_colors.append(hex_color)
101
  return hex_colors
102
 
103
- # Generate Image Caption Using Hugging Face BLIP
104
  def generate_caption(image):
105
- inputs = blip_processor(images=image, return_tensors="pt")
106
- output = blip_model.generate(**inputs)
107
- caption = blip_processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
108
  return caption
109
 
110
- # Translate Caption to Arabic Using mBART
111
  def translate_to_arabic(text):
112
- mbart_tokenizer.src_lang = "en_XX"
113
- encoded = mbart_tokenizer(text, return_tensors="pt")
114
- generated_tokens = mbart_model.generate(
115
- **encoded,
116
- forced_bos_token_id=mbart_tokenizer.lang_code_to_id["ar_AR"]
117
- )
118
- translated_text = mbart_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
119
- return translated_text
120
-
121
- # Gradio Interface Function (Combining Elements)
122
- def process_image(image):
123
  try:
124
- # Ensure input is a PIL Image
125
- if isinstance(image, np.ndarray):
126
- image = Image.fromarray(image)
127
-
128
- # Convert to RGB format for PIL processing
129
- image_rgb = image.convert("RGB")
130
-
131
- # Load and resize the entire image
132
- resized_image_np = load_image(image_rgb)
133
-
134
- # Convert resized image to PIL Image for Gradio output
135
- resized_image_pil = Image.fromarray(resized_image_np)
136
-
137
- # Generate caption using BLIP model
138
- caption = generate_caption(image_rgb)
139
-
140
- # Translate caption to Arabic
141
- caption_arabic = translate_to_arabic(caption)
142
-
143
- # Extract dominant colors from the entire image
144
- colors = extract_colors(resized_image_np, k=8)
145
- color_palette = display_palette(colors)
146
 
147
- # Create palette image
148
- palette_image = create_palette_image(colors)
 
 
 
 
 
 
 
 
149
 
150
- # Combine English and Arabic captions
151
- bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
152
-
153
- return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
154
  except Exception as e:
155
- print(f"Error during processing: {e}")
156
- return "An error occurred during processing.", "", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # Create Gradio Interface using Blocks and add a submit button
159
- with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
160
- gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
 
 
 
 
 
 
 
161
  gr.Markdown(
162
  """
163
  <p style='text-align: center;'>
164
- Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.
 
165
  </p>
166
  """
167
  )
168
  with gr.Row():
169
  with gr.Column(scale=1):
170
- image_input = gr.Image(type="pil", label="Upload your image or select an example below")
 
 
 
 
 
171
  submit_button = gr.Button("Submit")
 
172
  gr.Examples(
173
  examples=examples,
174
  inputs=image_input,
@@ -176,15 +342,35 @@ with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
176
  examples_per_page=5,
177
  )
178
  with gr.Column(scale=1):
179
- caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
180
- palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
181
- palette_image_output = gr.Image(type="pil", label="Color Palette")
182
- resized_image_output = gr.Image(type="pil", label="Resized Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
184
  submit_button.click(
185
  fn=process_image,
186
  inputs=image_input,
187
- outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
 
 
 
 
 
188
  )
189
 
190
  # Launch Gradio Interface
 
1
  # app.py
2
 
3
+ # Import necessary libraries
4
  import numpy as np
5
  import gradio as gr
6
  from sklearn.cluster import KMeans
7
+ from transformers import pipeline
 
 
 
 
 
8
  from PIL import Image, ImageDraw
9
  import requests
10
  from io import BytesIO
11
 
12
# Silence library warnings so startup logs stay readable.
import warnings

warnings.filterwarnings('ignore')

# Both Hugging Face pipelines are created once at import time so that every
# inference request reuses the already-loaded weights.
print("Loading pipelines...")

# BLIP model that produces an English caption for an input image.
caption_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# mBART-50 many-to-many model configured for English -> Arabic translation.
translation_pipeline = pipeline(
    "translation",
    model="facebook/mbart-large-50-many-to-many-mmt",
    tokenizer="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="ar_AR",
)

print("Pipelines loaded successfully.")
39
# Function to Download Example Images
def download_example_images():
    """
    Download example images from Unsplash and save them locally.

    Each image is saved as ``example{idx}.jpg`` in the working directory.
    Failures (HTTP errors, network problems, bad image data) are logged
    and skipped so one broken URL does not prevent the app from starting.

    Returns:
        examples (list): One ``[filepath]`` entry per successfully
        downloaded image, in the nested-list format gr.Examples expects.
    """
    # List of image descriptions and URLs
    image_urls = [
        # URL format: ("Image Description", "Image URL")
        (
            "Sunset over Mountains",
            "https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512"
        ),
        (
            "Forest Path",
            "https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512"
        ),
        (
            "City Skyline",
            "https://images.unsplash.com/photo-1498598453737-8913e843c47b?w=512"
        ),
        (
            "Beach and Ocean",
            "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512"
        ),
        (
            "Desert Dunes",
            "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512"
        ),
        (
            "Snowy Mountain Peak",
            "https://images.unsplash.com/photo-1519608487953-e999c86e7455?w=512"
        ),
        (
            "Autumn Leaves",
            "https://images.unsplash.com/photo-1500530855697-b586d89ba3ee?w=512"
        ),
        (
            "City Street at Night",
            "https://images.unsplash.com/photo-1512453979798-5ea266f8880c?w=512"
        ),
        (
            "Calm Lake Reflection",
            "https://images.unsplash.com/photo-1506744038136-46273834b3fb?w=512"
        ),
        (
            "Lush Green Hills",
            "https://images.unsplash.com/photo-1501696461280-37c52f57e8c1?w=512"
        ),
    ]

    examples = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            # FIX: bounded timeout so a single stalled download cannot hang
            # application startup indefinitely (requests.get has no default).
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                # Open the image and save it locally
                img = Image.open(BytesIO(response.content))
                img.save(f'example{idx}.jpg')
                examples.append([f'example{idx}.jpg'])
            else:
                print(f"Failed to download image from {url}")
        except Exception as e:
            print(f"Exception occurred while downloading image: {e}")
    return examples

# Download example images and prepare examples list
examples = download_example_images()
109
 
110
# Function to Load and Process Image
def load_image(image):
    """
    Resize the input image to a fixed size for consistent processing.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        resized_image_np (numpy.ndarray): The image as an RGB array of
        shape (300, 300, 3).
    """
    # FIX: convert to RGB *before* resizing so palette extraction always
    # sees exactly 3 channels. The original computed an unused RGB array
    # (dead code) and then resized the unconverted image, so RGBA/P-mode
    # inputs could produce arrays that break reshape(-1, 3) downstream.
    resized_image = image.convert('RGB').resize((300, 300), resample=Image.LANCZOS)
    resized_image_np = np.array(resized_image)

    return resized_image_np
129
 
130
# Function to Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """
    Extract the k dominant colors of an image via KMeans clustering.

    Args:
        image (numpy.ndarray): RGB image array of shape (H, W, 3).
        k (int): Number of clusters (colors) to extract.

    Returns:
        colors (numpy.ndarray): Array of shape (k, 3) holding the dominant
        RGB colors as integers on the 0-255 scale.
    """
    # One row per pixel, scaled into [0, 1] as float64 for clustering.
    samples = (image.reshape(-1, 3) / 255.0).astype(np.float64)

    # Cluster the pixels; the fixed random_state keeps results repeatable.
    clusterer = KMeans(
        n_clusters=k,
        random_state=0,
        n_init=10,
        max_iter=300
    )
    clusterer.fit(samples)

    # Cluster centres are normalized; map them back to the 0-255 scale.
    return (clusterer.cluster_centers_ * 255).astype(int)
161
 
162
# Function to Create an Image for the Color Palette
def create_palette_image(colors, swatch_size=100):
    """
    Create a horizontal strip image visualizing the color palette.

    Args:
        colors (numpy.ndarray): An array of the dominant RGB colors.
        swatch_size (int): Width and height in pixels of each color swatch.
            Defaults to 100, matching the previous hard-coded size.

    Returns:
        palette_image (PIL.Image.Image): The generated palette image of
        size (swatch_size * len(colors), swatch_size).
    """
    num_colors = len(colors)
    palette_image = Image.new(
        "RGB",
        (swatch_size * num_colors, swatch_size)
    )

    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Clamp channels to the valid 0-255 range and coerce to ints for PIL.
        fill = tuple(np.clip(color, 0, 255).astype(int))
        # Draw one square swatch per color, left to right.
        draw.rectangle(
            [i * swatch_size, 0, (i + 1) * swatch_size, swatch_size],
            fill=fill
        )

    return palette_image
192
 
193
# Function to Display Color Palette as Hex Codes
def display_palette(colors):
    """
    Convert RGB colors to lowercase "#rrggbb" hexadecimal strings.

    Args:
        colors (numpy.ndarray): An array of the dominant RGB colors.

    Returns:
        hex_colors (list): One hex color code per input color.
    """
    hex_colors = []
    for color in colors:
        # Clamp each channel into 0-255 before formatting.
        r, g, b = np.clip(color, 0, 255).astype(int)
        hex_colors.append(f"#{r:02x}{g:02x}{b:02x}")
    return hex_colors
216
 
217
# Function to Generate Image Caption Using Pipeline
def generate_caption(image):
    """
    Generate an English caption for the image with the BLIP pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        caption (str): The generated caption text.
    """
    # The pipeline returns a list of dicts; the caption lives under the
    # 'generated_text' key of the first entry.
    return caption_pipeline(image)[0]['generated_text']
232
 
233
# Function to Translate Caption to Arabic Using Pipeline
def translate_to_arabic(text):
    """
    Translate English text to Arabic, collapsing adjacent repeated words.

    Args:
        text (str): The English text to translate.

    Returns:
        translated_text (str): The Arabic translation, or the string
        "Translation Error" if the pipeline raises.
    """
    try:
        # Run the translation pipeline and pull out the translated string.
        translated_text = translation_pipeline(text)[0]['translation_text']

        # Lightweight post-processing: drop a word when it immediately
        # repeats the previous one (a common artifact of the model).
        cleaned_words = []
        for word in translated_text.split():
            if not cleaned_words or cleaned_words[-1] != word:
                cleaned_words.append(word)

        return ' '.join(cleaned_words)
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation Error"
264
+
265
# Gradio Interface Function (Combining All Elements)
def process_image(image):
    """
    Process the input image into a bilingual caption and color palette.

    Args:
        image (PIL.Image.Image or numpy.ndarray): The input image.

    Returns:
        tuple: (bilingual caption, comma-separated hex codes, palette
        image, resized image). On failure returns an error message with
        empty/None placeholders so the UI degrades gracefully.
    """
    try:
        # Ensure input is a PIL Image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to RGB format
        image_rgb = image.convert("RGB")

        # Load and resize the image
        resized_image_np = load_image(image_rgb)
        resized_image_pil = Image.fromarray(resized_image_np)

        # Generate caption using the caption pipeline
        caption = generate_caption(image_rgb)

        # Translate caption to Arabic using the translation pipeline
        caption_arabic = translate_to_arabic(caption)

        # Extract dominant colors and render them as a palette
        colors = extract_colors(resized_image_np, k=8)
        color_palette = display_palette(colors)
        palette_image = create_palette_image(colors)

        # Combine English and Arabic captions
        bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

        return (
            bilingual_caption,
            ", ".join(color_palette),
            palette_image,
            resized_image_pil
        )
    except Exception as e:
        # FIX: the rewrite dropped the error handling the previous version
        # had; without it any failure surfaces as a raw traceback in the
        # Gradio UI. Return the error tuple matching the output components.
        print(f"Error during processing: {e}")
        return "An error occurred during processing.", "", None, None
309
 
310
  # Create Gradio Interface using Blocks and add a submit button
311
+ with gr.Blocks(
312
+ css=".gradio-container { height: 1000px !important; }"
313
+ ) as demo:
314
+ # Title and Description
315
+ gr.Markdown(
316
+ "<h1 style='text-align: center;'>"
317
+ "Palette Generator from Image with Image Captioning"
318
+ "</h1>"
319
+ )
320
  gr.Markdown(
321
  """
322
  <p style='text-align: center;'>
323
+ Upload an image or select one of the example images below to generate
324
+ a color palette and a description of the image in both English and Arabic.
325
  </p>
326
  """
327
  )
328
  with gr.Row():
329
  with gr.Column(scale=1):
330
+ # Image Input Component
331
+ image_input = gr.Image(
332
+ type="pil",
333
+ label="Upload your image or select an example below"
334
+ )
335
+ # Submit Button
336
  submit_button = gr.Button("Submit")
337
+ # Examples Component
338
  gr.Examples(
339
  examples=examples,
340
  inputs=image_input,
 
342
  examples_per_page=5,
343
  )
344
  with gr.Column(scale=1):
345
+ # Output Components
346
+ caption_output = gr.Textbox(
347
+ label="Bilingual Caption",
348
+ lines=5,
349
+ max_lines=10
350
+ )
351
+ palette_hex_output = gr.Textbox(
352
+ label="Color Palette Hex Codes",
353
+ lines=2
354
+ )
355
+ palette_image_output = gr.Image(
356
+ type="pil",
357
+ label="Color Palette"
358
+ )
359
+ resized_image_output = gr.Image(
360
+ type="pil",
361
+ label="Resized Image"
362
+ )
363
 
364
+ # Define the action on submit button click
365
  submit_button.click(
366
  fn=process_image,
367
  inputs=image_input,
368
+ outputs=[
369
+ caption_output,
370
+ palette_hex_output,
371
+ palette_image_output,
372
+ resized_image_output
373
+ ],
374
  )
375
 
376
  # Launch Gradio Interface