ayajoharji committed on
Commit
b9713ec
·
verified ·
1 Parent(s): c4b3015

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from sklearn.cluster import KMeans
4
+ from transformers import (
5
+ BlipProcessor,
6
+ BlipForConditionalGeneration,
7
+ MBartForConditionalGeneration,
8
+ MBart50TokenizerFast,
9
+ )
10
+ from PIL import Image, ImageDraw
11
+ import requests
12
+ from io import BytesIO
13
+
14
# Download example images
def download_example_images():
    """Fetch demo images from Unsplash, save them locally, and return their paths.

    Returns:
        list[list[str]]: One-element lists of saved filenames, in the shape
        gr.Examples expects. Images that fail to download are skipped.
    """
    image_urls = [
        # URL format: ("Image Description", "Image URL")
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
    ]

    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            # A timeout prevents the whole app from hanging at startup if
            # Unsplash is slow or unreachable; raise_for_status also catches
            # non-200 responses (the original only compared status_code).
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as exc:
            # Best-effort: a failed download should shrink the examples list,
            # not crash the app at import time.
            print(f"Failed to download image from {url}: {exc}")
            continue
        img = Image.open(BytesIO(response.content))
        img.save(f'example{idx}.jpg')
        example_images.append([f'example{idx}.jpg'])
    return example_images
35
+
36
# Download example images and prepare examples list
# NOTE: runs at import time and performs network I/O; failed downloads are
# logged inside download_example_images and simply shrink this list.
examples = download_example_images()
38
+
39
# Load and Process the Entire Image
def load_image(image):
    """Resize a PIL image to 300x300 and return it as an RGB numpy array.

    Args:
        image: PIL.Image.Image in any mode.

    Returns:
        numpy.ndarray of shape (300, 300, 3), dtype uint8.
    """
    # Fix: the original also built a full-resolution RGB array here that was
    # never used — dead work on potentially large images. Converting to RGB
    # *before* resizing guarantees a 3-channel result even for RGBA or
    # grayscale inputs.
    rgb_image = image.convert('RGB')
    # LANCZOS gives the best quality for downscaling.
    resized_image = rgb_image.resize((300, 300), resample=Image.LANCZOS)
    return np.array(resized_image)
49
+
50
# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """Find the k dominant colors of an RGB image via K-means clustering.

    Args:
        image: numpy array of shape (H, W, 3) with uint8 pixel values.
        k: number of color clusters to extract.

    Returns:
        numpy array of shape (k, 3) with integer RGB values in 0-255.
    """
    # Treat each pixel as one sample; scale into [0, 1] float64 for K-means.
    samples = image.reshape(-1, 3).astype(np.float64) / 255.0
    # Fixed random_state keeps the palette deterministic for a given image.
    clusterer = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    clusterer.fit(samples)
    # The cluster centers are the dominant colors; map back to 0-255.
    return (clusterer.cluster_centers_ * 255).astype(int)
64
+
65
# Create an Image for the Color Palette
def create_palette_image(colors):
    """Render the given colors as a horizontal strip of 100x100 swatches.

    Args:
        colors: iterable of RGB triples (values clamped into 0-255).

    Returns:
        PIL.Image.Image of size (100 * len(colors), 100).
    """
    swatch = 100
    palette = Image.new("RGB", (swatch * len(colors), swatch))
    draw = ImageDraw.Draw(palette)
    for index, rgb in enumerate(colors):
        # Clamp to the valid byte range and force plain ints so PIL
        # accepts the fill tuple.
        fill = tuple(np.clip(rgb, 0, 255).astype(int))
        left = index * swatch
        draw.rectangle([left, 0, left + swatch, swatch], fill=fill)
    return palette
79
+
80
# Display Color Palette as Hex Codes
def display_palette(colors):
    """Convert RGB color triples to lowercase hex strings.

    Args:
        colors: iterable of RGB triples; values outside 0-255 are clamped.

    Returns:
        list[str]: hex codes such as "#ff8800", one per input color.
    """
    hex_codes = []
    for rgb in colors:
        # Clamp into byte range first so the hex formatting never sees
        # negatives or values above 255.
        r, g, b = np.clip(rgb, 0, 255).astype(int)
        hex_codes.append(f"#{r:02x}{g:02x}{b:02x}")
    return hex_codes
89
+
90
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    """Produce an English caption for a PIL image using BLIP base.

    The processor and model are created lazily on the first call and cached
    as attributes on the function object, so later calls skip the load.
    """
    if not hasattr(generate_caption, "processor"):
        generate_caption.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        generate_caption.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )

    blip_inputs = generate_caption.processor(images=image, return_tensors="pt")
    token_ids = generate_caption.model.generate(**blip_inputs)
    # Decode the first (only) sequence into plain text.
    return generate_caption.processor.decode(token_ids[0], skip_special_tokens=True)
103
+
104
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    """Translate English text to Arabic with mBART-50 many-to-many.

    The tokenizer and model are created lazily on the first call and cached
    on the function object so subsequent calls reuse them.
    """
    if not hasattr(translate_to_arabic, "tokenizer"):
        translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
    tokenizer = translate_to_arabic.tokenizer

    # Declare the source language, then force Arabic as the first generated
    # token so the model decodes into Arabic rather than echoing English.
    tokenizer.src_lang = "en_XX"
    encoded_input = tokenizer(text, return_tensors="pt")
    output_tokens = translate_to_arabic.model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
    )
    return tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
121
+
122
# Gradio Interface Function (Combining Elements)
def process_image(image):
    """Run the full pipeline on one uploaded image.

    Args:
        image: PIL image or numpy array from the Gradio input component.

    Returns:
        tuple: (bilingual caption string, comma-separated hex codes,
        palette PIL image, resized 300x300 PIL image).
    """
    # Gradio may hand us a numpy array; normalize to a PIL RGB image first.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    rgb_image = image.convert("RGB")

    # Resized copy is used both for color analysis and as a visual output.
    resized_np = load_image(rgb_image)
    resized_pil = Image.fromarray(resized_np)

    # Caption in English, then translate that caption to Arabic.
    english_caption = generate_caption(rgb_image)
    arabic_caption = translate_to_arabic(english_caption)
    bilingual_caption = f"English: {english_caption}\nArabic: {arabic_caption}"

    # Dominant colors -> hex code list and a rendered swatch image.
    dominant_colors = extract_colors(resized_np, k=8)
    hex_codes = display_palette(dominant_colors)
    palette_img = create_palette_image(dominant_colors)

    return bilingual_caption, ", ".join(hex_codes), palette_img, resized_pil
154
+
155
# Create Gradio Interface using Blocks and add a submit button
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    # Page header and usage instructions.
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        """
        <p style='text-align: center;'>
        Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.
        </p>
        """
    )
    with gr.Row():
        # Left column: image input, submit button, and clickable examples
        # (downloaded at import time by download_example_images).
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(
                examples=examples,
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
        # Right column: bilingual caption, hex codes, palette swatch,
        # and the resized preview image.
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")

    # Wire the button to the pipeline; outputs map positionally onto
    # process_image's 4-element return tuple.
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )

# Launch Gradio Interface
demo.launch()