File size: 9,939 Bytes
9685466
 
e90da55
b9713ec
 
 
64c5d24
b9713ec
 
 
521a5df
64c5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e90da55
64c5d24
 
 
 
 
 
 
 
 
 
f148e92
9685466
 
 
 
 
 
 
 
 
 
 
b9713ec
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
7df98bf
 
 
64c5d24
 
7df98bf
 
 
b9713ec
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
 
 
7df98bf
64c5d24
7df98bf
 
 
64c5d24
 
 
 
 
 
 
 
b9713ec
64c5d24
7df98bf
 
 
b9713ec
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
b9713ec
64c5d24
 
 
 
 
 
7df98bf
b9713ec
 
64c5d24
b9713ec
64c5d24
 
 
 
 
7df98bf
b9713ec
 
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
7df98bf
 
 
 
64c5d24
 
 
 
 
 
7df98bf
 
b9713ec
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
 
 
 
7df98bf
b9713ec
64c5d24
b9713ec
64c5d24
 
 
 
 
 
 
 
 
f148e92
64c5d24
 
 
f148e92
e90da55
64c5d24
e90da55
64c5d24
 
 
 
 
e90da55
 
64c5d24
f148e92
64c5d24
f148e92
64c5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9713ec
7df98bf
64c5d24
 
 
 
 
 
 
 
 
b9713ec
7df98bf
 
64c5d24
 
7df98bf
 
b9713ec
 
 
64c5d24
 
 
 
 
 
b9713ec
9685466
7df98bf
9685466
7df98bf
 
9685466
 
7df98bf
b9713ec
64c5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9713ec
64c5d24
b9713ec
 
 
64c5d24
 
 
 
 
 
b9713ec
 
 
9685466
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# app.py

# Import Libraries
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import pipeline
from PIL import Image, ImageDraw
import requests
from io import BytesIO

# Suppress warnings for cleaner output.
# NOTE: filterwarnings('ignore') with no category silences every warning in
# the process, including deprecation warnings raised by the libraries above.
import warnings
warnings.filterwarnings('ignore')

# Load pipelines once at import time so each request reuses the same
# model objects instead of reloading them on every inference.
print("Loading pipelines...")

# Image Captioning Pipeline
# "image-to-text" task backed by Salesforce/blip-image-captioning-base;
# used by generate_caption() to produce the English caption.
caption_pipeline = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base"
)

# Translation Pipeline
# "translation" task backed by facebook/mbart-large-50-many-to-many-mmt,
# a many-to-many multilingual model. The source/target language codes are
# fixed here (en_XX -> ar_AR), so translate_to_arabic() only passes the text.
translation_pipeline = pipeline(
    "translation",
    model="facebook/mbart-large-50-many-to-many-mmt",
    tokenizer="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="ar_AR"
)

print("Pipelines loaded successfully.")

# Example images shown under the upload widget, referenced by URL.
_EXAMPLE_IMAGE_URLS = (
    "https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512",
    "https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512",
    "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512",
    "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512",
    "https://images.unsplash.com/photo-1519608487953-e999c86e7455?w=512",
    "https://images.unsplash.com/photo-1500530855697-b586d89ba3ee?w=512",
    "https://images.unsplash.com/photo-1512453979798-5ea266f8880c?w=512",
    "https://images.unsplash.com/photo-1506744038136-46273834b3fb?w=512",
)

# The examples list is consumed as rows of input values; there is a single
# input (the image), so every row is a one-element list holding the URL.
image_examples = [[url] for url in _EXAMPLE_IMAGE_URLS]

# Function to Load and Process Image
def load_image(image):
    """
    Converts the input image to RGB, resizes it, and returns a numpy array.

    Args:
        image (PIL.Image.Image): The input image, in any mode that PIL can
            convert to RGB.

    Returns:
        resized_image_np (numpy.ndarray): The resized RGB image as a numpy
            array (300 x 300 pixels, three channels).
    """
    # Convert BEFORE resizing so the returned array always has exactly three
    # channels. Previously the RGB conversion result was discarded and the
    # original image was resized, so RGBA / palette / grayscale inputs
    # produced arrays that the later reshape(-1, 3) cannot interpret.
    rgb_image = image.convert('RGB')

    # Resize the image to (300, 300) for consistent processing
    resized_image = rgb_image.resize((300, 300), resample=Image.LANCZOS)

    return np.array(resized_image)

# Function to Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """
    Uses KMeans clustering to extract dominant colors from the image.

    Args:
        image (numpy.ndarray): The input image as a numpy array whose last
            axis holds the three RGB channels on a 0-255 scale.
        k (int): The number of clusters (colors) to extract.

    Returns:
        colors (numpy.ndarray): A (k, 3) integer array of the dominant
            colors on a 0-255 scale, one row per cluster center.
    """
    # Flatten the image to a 2D array of pixels and normalize to [0, 1].
    # Casting first guarantees float64 regardless of the input dtype.
    pixels = image.reshape(-1, 3).astype(np.float64) / 255.0

    # Apply KMeans clustering to find dominant colors.
    # random_state is fixed so the same image always yields the same palette.
    kmeans = KMeans(
        n_clusters=k,
        random_state=0,
        n_init=10,
        max_iter=300
    )
    kmeans.fit(pixels)

    # Convert normalized colors back to 0-255 scale.
    # Round to nearest instead of truncating: floating-point error can leave
    # a center slightly below its true value (e.g. 127.999...), which a plain
    # integer cast would turn into an off-by-one channel.
    colors = np.rint(kmeans.cluster_centers_ * 255).astype(int)
    return colors

# Function to Create an Image for the Color Palette
def create_palette_image(colors):
    """
    Renders the palette as a horizontal strip of square color swatches.

    Args:
        colors (numpy.ndarray): An array of the dominant colors, one RGB
            triple per entry.

    Returns:
        palette_image (PIL.Image.Image): An RGB image 100 pixels tall and
            100 pixels wide per color.
    """
    swatch_size = 100
    canvas = Image.new(
        "RGB",
        (swatch_size * len(colors), swatch_size)
    )
    painter = ImageDraw.Draw(canvas)

    left_edge = 0
    for rgb in colors:
        right_edge = left_edge + swatch_size
        # Clamp each channel into 0-255 and force integers before painting
        swatch_fill = tuple(np.clip(rgb, 0, 255).astype(int))
        painter.rectangle(
            [left_edge, 0, right_edge, swatch_size],
            fill=swatch_fill
        )
        left_edge = right_edge

    return canvas

# Function to Display Color Palette as Hex Codes
def display_palette(colors):
    """
    Formats each RGB color as a lowercase "#rrggbb" hex string.

    Args:
        colors (numpy.ndarray): An array of the dominant colors, one RGB
            triple per entry.

    Returns:
        hex_colors (list): A list of hex color codes, in input order.
    """
    def to_hex(rgb):
        # Clamp channels into 0-255 and truncate to integers before formatting
        channels = np.clip(rgb, 0, 255).astype(int)
        return "#{:02x}{:02x}{:02x}".format(
            channels[0],
            channels[1],
            channels[2]
        )

    return [to_hex(rgb) for rgb in colors]

# Function to Generate Image Caption Using Pipeline
def generate_caption(image):
    """
    Runs the module-level captioning pipeline on the image.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        caption (str): The text of the first generated caption.
    """
    predictions = caption_pipeline(image)
    # Only the first prediction returned by the pipeline is used
    first_prediction = predictions[0]
    return first_prediction['generated_text']

# Function to Translate Caption to Arabic Using Pipeline
def translate_to_arabic(text):
    """
    Translates English text to Arabic and collapses immediately repeated words.

    Args:
        text (str): The English text to translate.

    Returns:
        translated_text (str): The translated Arabic text with consecutive
            duplicate words removed and whitespace normalized to single
            spaces, or the string "Translation Error" if translation fails.
    """
    try:
        # Use the translation pipeline to translate the text
        result = translation_pipeline(text)
        translated_text = result[0]['translation_text']

        # Post-processing: drop a word only when it repeats the word directly
        # before it (non-adjacent repeats are kept).
        words = translated_text.split()
        cleaned_words = [
            word
            for index, word in enumerate(words)
            if index == 0 or word != words[index - 1]
        ]
        return ' '.join(cleaned_words)
    except Exception as e:
        # Deliberate best-effort: a failed translation must not break the
        # rest of the result, so report it and return a placeholder.
        print(f"Error during translation: {e}")
        return "Translation Error"

# Gradio Interface Function (Combining All Elements)
def process_image(image):
    """
    Processes the input image to generate a bilingual caption and color palette.

    Args:
        image (PIL.Image.Image or numpy.ndarray): The input image.

    Returns:
        tuple: Contains the bilingual caption (str), the comma-separated hex
            color codes (str), the palette image (PIL.Image.Image), and the
            resized image (PIL.Image.Image).

    Raises:
        gr.Error: If no image was provided.
    """
    # Submit can be clicked with nothing selected, in which case the input is
    # None. Fail with a readable message instead of an AttributeError.
    if image is None:
        raise gr.Error("Please upload an image or select an example first.")

    # Ensure input is a PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Convert to RGB format
    image_rgb = image.convert("RGB")

    # Load and resize the image
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)

    # Generate caption from the full-size RGB image (not the resized copy)
    caption = generate_caption(image_rgb)

    # Translate caption to Arabic using the translation pipeline
    caption_arabic = translate_to_arabic(caption)

    # Extract dominant colors from the resized image
    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)

    # Create palette image
    palette_image = create_palette_image(colors)

    # Combine English and Arabic captions
    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

    return (
        bilingual_caption,
        ", ".join(color_palette),
        palette_image,
        resized_image_pil
    )

# Build the Gradio UI with Blocks: a two-column layout with the image input,
# Submit button and examples on the left, and the four outputs on the right.
with gr.Blocks(
    css=".gradio-container { height: 1000px !important; }"
) as demo:
    # Title and Description
    gr.Markdown(
        "<h1 style='text-align: center;'>"
        "Palette Generator from Image with Image Captioning"
        "</h1>"
    )
    gr.Markdown(
        """
        <p style='text-align: center;'>
        Upload an image or select one of the example images below to generate
        a color palette and a description of the image in both English and Arabic.
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Image Input Component (delivered to process_image as a PIL image)
            image_input = gr.Image(
                type="pil",
                label="Upload your image or select an example below"
            )
            # Submit Button
            submit_button = gr.Button("Submit")
            # Examples Component using Image URLs directly
            gr.Examples(
                examples=image_examples,  # One row per example; each row holds a single image URL
                inputs=image_input,
                label="Example Images",
                examples_per_page=10,  # Larger than the number of examples, so all fit on one page
                fn=None,  # No function attached: picking an example only fills the input; Submit runs the processing
            )
        with gr.Column(scale=1):
            # Output Components, in the same order as process_image's return tuple
            caption_output = gr.Textbox(
                label="Bilingual Caption",
                lines=5,
                max_lines=10
            )
            palette_hex_output = gr.Textbox(
                label="Color Palette Hex Codes",
                lines=2
            )
            palette_image_output = gr.Image(
                type="pil",
                label="Color Palette"
            )
            resized_image_output = gr.Image(
                type="pil",
                label="Resized Image"
            )
    
    # Wire the Submit button: the outputs list must match the order of the
    # tuple returned by process_image.
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[
            caption_output,
            palette_hex_output,
            palette_image_output,
            resized_image_output
        ],
    )

# Launch Gradio Interface (runs at import time; there is no __main__ guard)
demo.launch()