import warnings
from io import BytesIO

import numpy as np
import requests
import gradio as gr
from PIL import Image, ImageDraw
from sklearn.cluster import KMeans
from transformers import pipeline

warnings.filterwarnings("ignore")

print("Loading pipelines...")

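# BLIP generates a short English caption for the input image.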
caption_pipeline = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
)

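# mBART-50 translates the caption; "en_XX" and "ar_AR" are the model's
# language codes for English and Arabic.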
translation_pipeline = pipeline(
    "translation",
    model="facebook/mbart-large-50-many-to-many-mmt",
    tokenizer="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="ar_AR",
)

print("Pipelines loaded successfully.")

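# Example images (Unsplash URLs) offered as one-click inputs in the UI.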
image_examples = [
    ["https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512"],
    ["https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512"],
    ["https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512"],
    ["https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512"],
    ["https://images.unsplash.com/photo-1519608487953-e999c86e7455?w=512"],
    ["https://images.unsplash.com/photo-1500530855697-b586d89ba3ee?w=512"],
    ["https://images.unsplash.com/photo-1512453979798-5ea266f8880c?w=512"],
    ["https://images.unsplash.com/photo-1506744038136-46273834b3fb?w=512"],
]


def load_image(image):
    """
    Converts the input image to RGB, resizes it to 300x300, and returns
    it as a numpy array.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        resized_image_np (numpy.ndarray): The resized image as a numpy array.
    """
    # Convert to RGB before resizing so the array is always H x W x 3,
    # even for grayscale or RGBA inputs.
    resized_image = image.convert("RGB").resize((300, 300), resample=Image.LANCZOS)
    resized_image_np = np.array(resized_image)
    return resized_image_np


def extract_colors(image, k=8):
    """
    Uses KMeans clustering to extract dominant colors from the image.

    Args:
        image (numpy.ndarray): The input image as a numpy array.
        k (int): The number of clusters (colors) to extract.

    Returns:
        colors (numpy.ndarray): An array of the dominant colors.
    """
    # Flatten the image into a list of RGB pixels and scale to [0, 1].
    pixels = image.reshape(-1, 3).astype(np.float64) / 255.0

    kmeans = KMeans(
        n_clusters=k,
        random_state=0,
        n_init=10,
        max_iter=300,
    )
    kmeans.fit(pixels)

    # Scale the cluster centers back up to 8-bit RGB values.
    colors = (kmeans.cluster_centers_ * 255).astype(int)
    return colors


def create_palette_image(colors):
    """
    Creates a visual representation of the color palette.

    Args:
        colors (numpy.ndarray): An array of the dominant colors.

    Returns:
        palette_image (PIL.Image.Image): The generated color palette image.
    """
    num_colors = len(colors)
    palette_height = 100
    palette_width = 100 * num_colors
    palette_image = Image.new("RGB", (palette_width, palette_height))

    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Clip to the valid 8-bit range; plain Python ints keep Pillow happy.
        fill = tuple(int(c) for c in np.clip(color, 0, 255))
        # Draw one 100x100 swatch per color, left to right.
        draw.rectangle(
            [i * 100, 0, (i + 1) * 100, palette_height],
            fill=fill,
        )

    return palette_image


def display_palette(colors):
    """
    Converts RGB colors to hexadecimal format.

    Args:
        colors (numpy.ndarray): An array of the dominant colors.

    Returns:
        hex_colors (list): A list of hex color codes.
    """
    hex_colors = []
    for color in colors:
        color = np.clip(color, 0, 255).astype(int)
        hex_colors.append(
            "#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2])
        )
    return hex_colors


def generate_caption(image):
    """
    Generates a caption for the input image using a pre-trained model.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        caption (str): The generated caption.
    """
    result = caption_pipeline(image)
    caption = result[0]["generated_text"]
    return caption


def translate_to_arabic(text):
    """
    Translates English text to Arabic using a pre-trained model, then
    collapses consecutive duplicate words in the output.

    Args:
        text (str): The English text to translate.

    Returns:
        translated_text (str): The translated Arabic text, or
        "Translation Error" on failure.
    """
    try:
        result = translation_pipeline(text)
        translated_text = result[0]["translation_text"]

        # The model occasionally stutters, emitting the same word twice in
        # a row; keep only the first of each consecutive duplicate.
        cleaned_words = []
        previous_word = ""
        for word in translated_text.split():
            if word != previous_word:
                cleaned_words.append(word)
            previous_word = word
        cleaned_translated_text = " ".join(cleaned_words)

        return cleaned_translated_text
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation Error"


def process_image(image):
    """
    Processes the input image to generate a bilingual caption and color palette.

    Args:
        image (PIL.Image.Image or numpy.ndarray): The input image.

    Returns:
        tuple: Contains bilingual caption, hex color codes, palette image,
        and resized image.
    """
    # Gradio may deliver a numpy array; normalize to a PIL RGB image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_rgb = image.convert("RGB")

    # Resize once; clustering a 300x300 image keeps KMeans fast.
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)

    # Caption the full-resolution image, then translate the caption.
    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)

    # Extract dominant colors and render them as hex codes and swatches.
    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)

    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

    return (
        bilingual_caption,
        ", ".join(color_palette),
        palette_image,
        resized_image_pil,
    )


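# Build the Gradio UI: image input and examples in the left column,
# caption and palette outputs in the right column.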
with gr.Blocks(
    css=".gradio-container { height: 1000px !important; }"
) as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>"
        "Palette Generator from Image with Image Captioning"
        "</h1>"
    )
    gr.Markdown(
        """
        <p style='text-align: center;'>
        Upload an image or select one of the example images below to generate
        a color palette and a description of the image in both English and Arabic.
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload your image or select an example below"
            )
            submit_button = gr.Button("Submit")
            gr.Examples(
                examples=image_examples,
                inputs=image_input,
                label="Example Images",
                examples_per_page=10,
            )

        with gr.Column(scale=1):
            caption_output = gr.Textbox(
                label="Bilingual Caption",
                lines=5,
                max_lines=10
            )
            palette_hex_output = gr.Textbox(
                label="Color Palette Hex Codes",
                lines=2
            )
            palette_image_output = gr.Image(
                type="pil",
                label="Color Palette"
            )
            resized_image_output = gr.Image(
                type="pil",
                label="Resized Image"
            )

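    # Wire the Submit button to the processing function.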
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[
            caption_output,
            palette_hex_output,
            palette_image_output,
            resized_image_output,
        ],
    )

if __name__ == "__main__":
    demo.launch()