ayajoharji committed on
Commit
7df98bf
·
verified ·
1 Parent(s): 35d0b13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -27
app.py CHANGED
@@ -1,5 +1,4 @@
1
- # Install necessary libraries
2
- pip install torch transformers gradio Pillow scikit-learn requests
3
 
4
  import numpy as np
5
  import gradio as gr
@@ -13,16 +12,11 @@ from transformers import (
13
  from PIL import Image, ImageDraw
14
  import requests
15
  from io import BytesIO
16
- # Load Hugging Face models globally
17
- print("Loading Hugging Face models...")
18
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
19
- caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
20
- tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
21
- translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
22
 
23
  # Download example images
24
  def download_example_images():
25
  image_urls = [
 
26
  ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
27
  ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
28
  ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
@@ -32,14 +26,13 @@ def download_example_images():
32
 
33
  example_images = []
34
  for idx, (description, url) in enumerate(image_urls, start=1):
35
- try:
36
- response = requests.get(url)
37
- response.raise_for_status()
38
  img = Image.open(BytesIO(response.content))
39
  img.save(f'example{idx}.jpg')
40
  example_images.append([f'example{idx}.jpg'])
41
- except requests.RequestException as e:
42
- print(f"Failed to download image from {url}: {e}")
43
  return example_images
44
 
45
  # Download example images and prepare examples list
@@ -47,77 +40,140 @@ examples = download_example_images()
47
 
48
  # Load and Process the Entire Image
49
  def load_image(image):
 
 
 
 
50
  resized_image = image.resize((300, 300), resample=Image.LANCZOS)
51
- return np.array(resized_image)
 
 
52
 
53
  # Extract Dominant Colors from the Image
54
  def extract_colors(image, k=8):
55
- pixels = image.reshape(-1, 3) / 255.0
 
 
 
 
 
 
56
  kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
57
  kmeans.fit(pixels)
58
- return (kmeans.cluster_centers_ * 255).astype(int)
 
 
59
 
60
  # Create an Image for the Color Palette
61
  def create_palette_image(colors):
62
  num_colors = len(colors)
63
- palette_image = Image.new("RGB", (100 * num_colors, 100))
 
 
 
64
  draw = ImageDraw.Draw(palette_image)
65
  for i, color in enumerate(colors):
 
66
  color = tuple(np.clip(color, 0, 255).astype(int))
67
- draw.rectangle([i * 100, 0, (i + 1) * 100, 100], fill=color)
 
68
  return palette_image
69
 
70
  # Display Color Palette as Hex Codes
71
  def display_palette(colors):
72
- return ["#{:02x}{:02x}{:02x}".format(*np.clip(color, 0, 255)) for color in colors]
 
 
 
 
 
 
73
 
74
  # Generate Image Caption Using Hugging Face BLIP
75
  def generate_caption(image):
 
 
 
 
 
 
 
76
  inputs = processor(images=image, return_tensors="pt")
77
- output = caption_model.generate(**inputs)
78
- return processor.decode(output[0], skip_special_tokens=True)
 
79
 
80
  # Translate Caption to Arabic Using mBART
81
  def translate_to_arabic(text):
 
 
 
 
 
 
 
82
  tokenizer.src_lang = "en_XX"
83
  encoded = tokenizer(text, return_tensors="pt")
84
- generated_tokens = translation_model.generate(
85
  **encoded,
86
  forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"]
87
  )
88
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
89
 
90
  # Gradio Interface Function (Combining Elements)
91
  def process_image(image):
 
92
  if isinstance(image, np.ndarray):
93
  image = Image.fromarray(image)
94
 
 
95
  image_rgb = image.convert("RGB")
 
 
96
  resized_image_np = load_image(image_rgb)
 
 
97
  resized_image_pil = Image.fromarray(resized_image_np)
98
 
 
99
  caption = generate_caption(image_rgb)
 
 
100
  caption_arabic = translate_to_arabic(caption)
101
 
 
102
  colors = extract_colors(resized_image_np, k=8)
103
  color_palette = display_palette(colors)
 
 
104
  palette_image = create_palette_image(colors)
105
 
 
106
  bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
107
 
108
  return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
109
 
110
- # Create Gradio Interface
111
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
112
  gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
113
  gr.Markdown(
114
- "<p style='text-align: center;'>Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.</p>"
 
 
 
 
115
  )
116
  with gr.Row():
117
  with gr.Column(scale=1):
118
  image_input = gr.Image(type="pil", label="Upload your image or select an example below")
119
  submit_button = gr.Button("Submit")
120
- gr.Examples(examples=examples, inputs=image_input, label="Example Images", examples_per_page=5)
 
 
 
 
 
121
  with gr.Column(scale=1):
122
  caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
123
  palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
@@ -132,4 +188,3 @@ with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
132
 
133
  # Launch Gradio Interface
134
  demo.launch()
135
-
 
1
+ # app.py
 
2
 
3
  import numpy as np
4
  import gradio as gr
 
12
  from PIL import Image, ImageDraw
13
  import requests
14
  from io import BytesIO
 
 
 
 
 
 
15
 
16
  # Download example images
17
  def download_example_images():
18
  image_urls = [
19
+ # URL format: ("Image Description", "Image URL")
20
  ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
21
  ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
22
  ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
 
26
 
27
  example_images = []
28
  for idx, (description, url) in enumerate(image_urls, start=1):
29
+ response = requests.get(url)
30
+ if response.status_code == 200:
 
31
  img = Image.open(BytesIO(response.content))
32
  img.save(f'example{idx}.jpg')
33
  example_images.append([f'example{idx}.jpg'])
34
+ else:
35
+ print(f"Failed to download image from {url}")
36
  return example_images
37
 
38
  # Download example images and prepare examples list
 
40
 
41
  # Load and Process the Entire Image
42
def load_image(image):
    """Resize *image* to 300x300 and return it as an RGB numpy array.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image. It is converted to RGB so the result always has
        exactly 3 channels, which `extract_colors` relies on.

    Returns
    -------
    np.ndarray
        Array of shape (300, 300, 3), dtype uint8.
    """
    # Convert BEFORE resizing so non-RGB inputs (grayscale, RGBA, palette)
    # still yield a 3-channel array. The original computed this conversion
    # into an unused local and then resized the unconverted image.
    rgb_image = image.convert("RGB")
    # LANCZOS is the highest-quality PIL filter for downsampling.
    resized_image = rgb_image.resize((300, 300), resample=Image.LANCZOS)
    return np.array(resized_image)
51
 
52
  # Extract Dominant Colors from the Image
53
def extract_colors(image, k=8):
    """Return the *k* dominant colors of an RGB image array.

    Pixels are clustered with K-means in normalized [0, 1] space and the
    cluster centers are scaled back to integer 0-255 RGB values.

    Parameters
    ----------
    image : np.ndarray
        RGB image array of shape (H, W, 3).
    k : int
        Number of color clusters to extract.

    Returns
    -------
    np.ndarray
        Integer array of shape (k, 3) with values in 0-255.
    """
    # One row per pixel, scaled to [0, 1] float64 samples for clustering.
    samples = image.reshape(-1, 3).astype(np.float64) / 255.0
    clusterer = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    clusterer.fit(samples)
    # Map the normalized cluster centers back onto the 0-255 color scale.
    return (clusterer.cluster_centers_ * 255).astype(int)
66
 
67
  # Create an Image for the Color Palette
68
def create_palette_image(colors):
    """Render *colors* as a horizontal strip of 100x100 pixel swatches.

    Parameters
    ----------
    colors : sequence of RGB triples (any numeric type).

    Returns
    -------
    PIL.Image.Image
        An RGB image 100 pixels tall and 100 * len(colors) wide.
    """
    swatch = 100
    strip = Image.new("RGB", (swatch * len(colors), swatch))
    draw = ImageDraw.Draw(strip)
    for i, raw in enumerate(colors):
        # Clamp each component into the valid 0-255 range and force ints,
        # since cluster centers may be floats or slightly out of range.
        rgb = tuple(np.clip(raw, 0, 255).astype(int))
        draw.rectangle([i * swatch, 0, (i + 1) * swatch, swatch], fill=rgb)
    return strip
81
 
82
  # Display Color Palette as Hex Codes
83
def display_palette(colors):
    """Convert a sequence of RGB triples to ``#rrggbb`` hex strings.

    Components are clamped to 0-255 and cast to int before formatting,
    so float or out-of-range cluster centers are handled safely.
    """
    return [
        "#{:02x}{:02x}{:02x}".format(*np.clip(rgb, 0, 255).astype(int))
        for rgb in colors
    ]
91
 
92
  # Generate Image Caption Using Hugging Face BLIP
93
def generate_caption(image):
    """Generate an English caption for *image* with the BLIP base model.

    The processor and model are cached as attributes on this function the
    first time it runs, so the expensive ``from_pretrained`` loads happen
    only once per process.

    Parameters
    ----------
    image : PIL.Image.Image
        RGB image to caption.

    Returns
    -------
    str
        The decoded caption text.
    """
    if not hasattr(generate_caption, "processor"):
        generate_caption.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        generate_caption.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    blip_processor = generate_caption.processor
    blip_model = generate_caption.model

    batch = blip_processor(images=image, return_tensors="pt")
    token_ids = blip_model.generate(**batch)
    return blip_processor.decode(token_ids[0], skip_special_tokens=True)
105
 
106
  # Translate Caption to Arabic Using mBART
107
def translate_to_arabic(text):
    """Translate English *text* to Arabic with mBART-50 many-to-many.

    The tokenizer and model are cached as attributes on this function so
    the ``from_pretrained`` downloads run only once per process.

    Parameters
    ----------
    text : str
        English source sentence.

    Returns
    -------
    str
        The Arabic translation.
    """
    if not hasattr(translate_to_arabic, "tokenizer"):
        translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

    mbart_tokenizer = translate_to_arabic.tokenizer
    mbart_model = translate_to_arabic.model

    # mBART requires the source language to be set before encoding.
    mbart_tokenizer.src_lang = "en_XX"
    batch = mbart_tokenizer(text, return_tensors="pt")
    # Force the decoder to begin generation in Arabic.
    output_ids = mbart_model.generate(
        **batch,
        forced_bos_token_id=mbart_tokenizer.lang_code_to_id["ar_AR"]
    )
    return mbart_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
123
 
124
  # Gradio Interface Function (Combining Elements)
125
def process_image(image):
    """Run the full analysis pipeline for one uploaded image.

    Parameters
    ----------
    image : PIL.Image.Image or np.ndarray
        The image supplied by Gradio.

    Returns
    -------
    tuple
        (bilingual caption str, comma-joined hex palette str,
         palette strip PIL image, 300x300 preview PIL image)
    """
    # Gradio may hand us a raw array; normalize to a PIL image first.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    rgb = image.convert("RGB")

    # 300x300 preview: numpy copy for clustering, PIL copy for display.
    preview_np = load_image(rgb)
    preview_pil = Image.fromarray(preview_np)

    # Caption in English, then translate, then combine both languages.
    english = generate_caption(rgb)
    arabic = translate_to_arabic(english)
    bilingual_caption = f"English: {english}\nArabic: {arabic}"

    # Dominant-color palette as hex codes plus a rendered strip.
    palette_colors = extract_colors(preview_np, k=8)
    hex_codes = display_palette(palette_colors)
    palette_strip = create_palette_image(palette_colors)

    return bilingual_caption, ", ".join(hex_codes), palette_strip, preview_pil
156
 
157
+ # Create Gradio Interface using Blocks and add a submit button
158
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
159
  gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
160
  gr.Markdown(
161
+ """
162
+ <p style='text-align: center;'>
163
+ Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.
164
+ </p>
165
+ """
166
  )
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
  image_input = gr.Image(type="pil", label="Upload your image or select an example below")
170
  submit_button = gr.Button("Submit")
171
+ gr.Examples(
172
+ examples=examples,
173
+ inputs=image_input,
174
+ label="Example Images",
175
+ examples_per_page=5,
176
+ )
177
  with gr.Column(scale=1):
178
  caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
179
  palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
 
188
 
189
  # Launch Gradio Interface
190
  demo.launch()