ayajoharji commited on
Commit
521a5df
·
verified ·
1 Parent(s): b1ed444

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -81
app.py CHANGED
@@ -11,10 +11,16 @@ from PIL import Image, ImageDraw
11
  import requests
12
  from io import BytesIO
13
 
 
 
 
 
 
 
 
14
  # Download example images
15
  def download_example_images():
16
  image_urls = [
17
- # URL format: ("Image Description", "Image URL")
18
  ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
19
  ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
20
  ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
@@ -24,13 +30,14 @@ def download_example_images():
24
 
25
  example_images = []
26
  for idx, (description, url) in enumerate(image_urls, start=1):
27
- response = requests.get(url)
28
- if response.status_code == 200:
 
29
  img = Image.open(BytesIO(response.content))
30
  img.save(f'example{idx}.jpg')
31
  example_images.append([f'example{idx}.jpg'])
32
- else:
33
- print(f"Failed to download image from {url}")
34
  return example_images
35
 
36
  # Download example images and prepare examples list
@@ -38,140 +45,77 @@ examples = download_example_images()
38
 
39
  # Load and Process the Entire Image
40
  def load_image(image):
41
- # Convert PIL image to numpy array (RGB)
42
- image_np = np.array(image.convert('RGB'))
43
-
44
- # Resize the image for better processing
45
  resized_image = image.resize((300, 300), resample=Image.LANCZOS)
46
- resized_image_np = np.array(resized_image)
47
-
48
- return resized_image_np
49
 
50
  # Extract Dominant Colors from the Image
51
  def extract_colors(image, k=8):
52
- # Flatten the image
53
- pixels = image.reshape(-1, 3)
54
- # Normalize pixel values to [0, 1]
55
- pixels = pixels / 255.0
56
- # Ensure data type is float64
57
- pixels = pixels.astype(np.float64)
58
- # Apply K-means clustering to find dominant colors
59
  kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
60
  kmeans.fit(pixels)
61
- # Convert normalized colors back to 0-255 scale
62
- colors = (kmeans.cluster_centers_ * 255).astype(int)
63
- return colors
64
 
65
  # Create an Image for the Color Palette
66
  def create_palette_image(colors):
67
  num_colors = len(colors)
68
- palette_height = 100
69
- palette_width = 100 * num_colors
70
- palette_image = Image.new("RGB", (palette_width, palette_height))
71
-
72
  draw = ImageDraw.Draw(palette_image)
73
  for i, color in enumerate(colors):
74
- # Ensure color values are within the valid range and integers
75
  color = tuple(np.clip(color, 0, 255).astype(int))
76
- draw.rectangle([i * 100, 0, (i + 1) * 100, palette_height], fill=color)
77
-
78
  return palette_image
79
 
80
  # Display Color Palette as Hex Codes
81
  def display_palette(colors):
82
- hex_colors = []
83
- for color in colors:
84
- # Ensure color values are within valid range and integers
85
- color = np.clip(color, 0, 255).astype(int)
86
- hex_color = "#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2])
87
- hex_colors.append(hex_color)
88
- return hex_colors
89
 
90
  # Generate Image Caption Using Hugging Face BLIP
91
  def generate_caption(image):
92
- # Load models only once
93
- if 'processor' not in generate_caption.__dict__:
94
- generate_caption.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
95
- generate_caption.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
96
- processor = generate_caption.processor
97
- model = generate_caption.model
98
-
99
  inputs = processor(images=image, return_tensors="pt")
100
- output = model.generate(**inputs)
101
- caption = processor.decode(output[0], skip_special_tokens=True)
102
- return caption
103
 
104
  # Translate Caption to Arabic Using mBART
105
  def translate_to_arabic(text):
106
- # Load models only once
107
- if 'tokenizer' not in translate_to_arabic.__dict__:
108
- translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
109
- translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
110
- tokenizer = translate_to_arabic.tokenizer
111
- model = translate_to_arabic.model
112
-
113
  tokenizer.src_lang = "en_XX"
114
  encoded = tokenizer(text, return_tensors="pt")
115
- generated_tokens = model.generate(
116
  **encoded,
117
  forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"]
118
  )
119
- translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
120
- return translated_text
121
 
122
  # Gradio Interface Function (Combining Elements)
123
  def process_image(image):
124
- # Ensure input is a PIL Image
125
  if isinstance(image, np.ndarray):
126
  image = Image.fromarray(image)
127
 
128
- # Convert to RGB format for PIL processing
129
  image_rgb = image.convert("RGB")
130
-
131
- # Load and resize the entire image
132
  resized_image_np = load_image(image_rgb)
133
-
134
- # Convert resized image to PIL Image for Gradio output
135
  resized_image_pil = Image.fromarray(resized_image_np)
136
 
137
- # Generate caption using BLIP model
138
  caption = generate_caption(image_rgb)
139
-
140
- # Translate caption to Arabic
141
  caption_arabic = translate_to_arabic(caption)
142
 
143
- # Extract dominant colors from the entire image
144
  colors = extract_colors(resized_image_np, k=8)
145
  color_palette = display_palette(colors)
146
-
147
- # Create palette image
148
  palette_image = create_palette_image(colors)
149
 
150
- # Combine English and Arabic captions
151
  bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
152
 
153
  return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
154
 
155
- # Create Gradio Interface using Blocks and add a submit button
156
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
157
  gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
158
  gr.Markdown(
159
- """
160
- <p style='text-align: center;'>
161
- Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.
162
- </p>
163
- """
164
  )
165
  with gr.Row():
166
  with gr.Column(scale=1):
167
  image_input = gr.Image(type="pil", label="Upload your image or select an example below")
168
  submit_button = gr.Button("Submit")
169
- gr.Examples(
170
- examples=examples,
171
- inputs=image_input,
172
- label="Example Images",
173
- examples_per_page=5,
174
- )
175
  with gr.Column(scale=1):
176
  caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
177
  palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
@@ -186,3 +130,4 @@ with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
186
 
187
  # Launch Gradio Interface
188
  demo.launch()
 
 
11
  import requests
12
  from io import BytesIO
13
 
14
# Load Hugging Face models globally, once at import time, so each request
# doesn't pay the (multi-second) model-loading cost per call.
print("Loading Hugging Face models...")
# BLIP: image -> English caption
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# mBART-50: English caption -> Arabic translation
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
21
  # Download example images
22
  def download_example_images():
23
  image_urls = [
 
24
  ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
25
  ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
26
  ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
 
30
 
31
  example_images = []
32
  for idx, (description, url) in enumerate(image_urls, start=1):
33
+ try:
34
+ response = requests.get(url)
35
+ response.raise_for_status()
36
  img = Image.open(BytesIO(response.content))
37
  img.save(f'example{idx}.jpg')
38
  example_images.append([f'example{idx}.jpg'])
39
+ except requests.RequestException as e:
40
+ print(f"Failed to download image from {url}: {e}")
41
  return example_images
42
 
43
  # Download example images and prepare examples list
 
45
 
46
# Load and Process the Entire Image
def load_image(image, size=(300, 300)):
    """Resize a PIL image and return it as a numpy array.

    Args:
        image: PIL.Image to resize.
        size: (width, height) target; defaults to the app's 300x300 so
            existing callers are unaffected.

    Returns:
        numpy array of the resized image (LANCZOS resampling for quality).
    """
    resized = image.resize(size, resample=Image.LANCZOS)
    return np.array(resized)
 
 
50
 
51
# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """Cluster the image's pixels with K-means and return the k dominant
    colors as an integer array on the 0-255 scale.
    """
    # Flatten to an (n_pixels, 3) matrix and normalize to [0, 1] for clustering.
    samples = image.reshape(-1, 3) / 255.0
    clusterer = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    clusterer.fit(samples)
    # Cluster centers are the dominant colors; scale back to 0-255 ints.
    dominant = clusterer.cluster_centers_ * 255
    return dominant.astype(int)
 
 
57
 
58
# Create an Image for the Color Palette
def create_palette_image(colors):
    """Render the given RGB colors as a horizontal strip of 100x100 swatches."""
    swatch = 100
    strip = Image.new("RGB", (swatch * len(colors), swatch))
    painter = ImageDraw.Draw(strip)
    for idx, rgb in enumerate(colors):
        # Clamp to the valid byte range; PIL requires an int tuple fill.
        fill = tuple(np.clip(rgb, 0, 255).astype(int))
        painter.rectangle([idx * swatch, 0, (idx + 1) * swatch, swatch], fill=fill)
    return strip
67
 
68
# Display Color Palette as Hex Codes
def display_palette(colors):
    """Convert RGB color triples to hex strings like '#a1b2c3'.

    Components are clamped to [0, 255] and coerced to int before
    formatting — the '{:02x}' format raises ValueError on floats, and
    KMeans-derived colors may arrive as floats.
    """
    hex_colors = []
    for color in colors:
        r, g, b = (int(c) for c in np.clip(color, 0, 255))
        hex_colors.append(f"#{r:02x}{g:02x}{b:02x}")
    return hex_colors
 
 
 
 
 
 
71
 
72
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    """Produce an English caption for a PIL image using the global BLIP
    processor/model pair loaded at import time.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    token_ids = caption_model.generate(**model_inputs)
    # Decode the first (only) sequence, stripping BOS/EOS/pad tokens.
    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption
 
77
 
78
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    """Translate English text to Arabic with the global mBART-50 model."""
    # NOTE(review): mutates the shared tokenizer's src_lang in place; this
    # is fine while English is the only source language used by the app.
    tokenizer.src_lang = "en_XX"
    batch = tokenizer(text, return_tensors="pt")
    # Force Arabic as the first generated token to select the target language.
    arabic_bos = tokenizer.lang_code_to_id["ar_AR"]
    outputs = translation_model.generate(**batch, forced_bos_token_id=arabic_bos)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
87
 
88
# Gradio Interface Function (Combining Elements)
def process_image(image):
    """Run the full pipeline on one uploaded image.

    Returns the four Gradio outputs: a bilingual (English/Arabic) caption,
    the palette's hex codes as a comma-separated string, the palette image,
    and a 300x300 preview of the input.
    """
    # Gradio may hand us a numpy array; normalize to a PIL image first.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    rgb = image.convert("RGB")

    preview_np = load_image(rgb)
    preview_pil = Image.fromarray(preview_np)

    english = generate_caption(rgb)
    arabic = translate_to_arabic(english)
    bilingual_caption = f"English: {english}\nArabic: {arabic}"

    # Palette is computed on the resized copy to keep K-means fast.
    dominant = extract_colors(preview_np, k=8)
    hex_codes = display_palette(dominant)
    palette_img = create_palette_image(dominant)

    return bilingual_caption, ", ".join(hex_codes), palette_img, preview_pil
107
 
108
+ # Create Gradio Interface
109
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
110
  gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
111
  gr.Markdown(
112
+ "<p style='text-align: center;'>Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.</p>"
 
 
 
 
113
  )
114
  with gr.Row():
115
  with gr.Column(scale=1):
116
  image_input = gr.Image(type="pil", label="Upload your image or select an example below")
117
  submit_button = gr.Button("Submit")
118
+ gr.Examples(examples=examples, inputs=image_input, label="Example Images", examples_per_page=5)
 
 
 
 
 
119
  with gr.Column(scale=1):
120
  caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
121
  palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
 
130
 
131
  # Launch Gradio Interface
132
  demo.launch()
133
+