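"""Palette Generator from Image with Image Captioning.

A Gradio app that captions an uploaded image with BLIP, translates the caption
to Arabic with mBART-50, and extracts an 8-color k-means palette from the image.
"""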
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO


def download_example_images():
    """Download example images and save them locally for the Gradio examples panel."""
    image_urls = [
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
    ]

    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            # A timeout keeps startup from hanging on a slow connection.
            response = requests.get(url, timeout=10)
        except requests.RequestException as exc:
            print(f"Failed to download image from {url}: {exc}")
            continue
        if response.status_code == 200:
            img = Image.open(BytesIO(response.content))
            img.save(f"example{idx}.jpg")
            example_images.append([f"example{idx}.jpg"])
        else:
            print(f"Failed to download image from {url} (status {response.status_code})")
    return example_images


examples = download_example_images()


def load_image(image):
    """Resize a PIL image to 300x300 and return it as a NumPy RGB array."""
    resized_image = image.convert("RGB").resize((300, 300), resample=Image.LANCZOS)
    return np.array(resized_image)


def extract_colors(image, k=8):
    """Cluster the image's pixels with k-means and return the k dominant RGB colors."""
    # Flatten to one RGB pixel per row and scale to [0, 1] for clustering.
    pixels = image.reshape(-1, 3) / 255.0

    kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    kmeans.fit(pixels)

    # Scale the cluster centers back up to 0-255 integer values.
    colors = (kmeans.cluster_centers_ * 255).astype(int)
    return colors


def create_palette_image(colors):
    """Render the colors as a horizontal strip of 100x100 pixel swatches."""
    num_colors = len(colors)
    palette_height = 100
    palette_width = 100 * num_colors
    palette_image = Image.new("RGB", (palette_width, palette_height))

    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # PIL expects a tuple of plain Python ints for the fill color.
        color = tuple(int(c) for c in np.clip(color, 0, 255))
        draw.rectangle([i * 100, 0, (i + 1) * 100, palette_height], fill=color)

    return palette_image


def display_palette(colors):
    """Convert RGB color values to hex strings such as '#1a2b3c'."""
    hex_colors = []
    for color in colors:
        color = np.clip(color, 0, 255).astype(int)
        hex_colors.append("#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2]))
    return hex_colors


def generate_caption(image):
    """Caption the image in English with BLIP, loading the model once on first use."""
    # Cache the processor and model as function attributes so repeated
    # calls do not reload the weights.
    if "processor" not in generate_caption.__dict__:
        generate_caption.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        generate_caption.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    processor = generate_caption.processor
    model = generate_caption.model

    inputs = processor(images=image, return_tensors="pt")
    output = model.generate(**inputs)
    caption = processor.decode(output[0], skip_special_tokens=True)
    return caption


def translate_to_arabic(text):
    """Translate English text to Arabic with mBART-50, loading the model once on first use."""
    if "tokenizer" not in translate_to_arabic.__dict__:
        translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    tokenizer = translate_to_arabic.tokenizer
    model = translate_to_arabic.model

    # Translate from English (en_XX), forcing Arabic (ar_AR) as the target.
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
    )
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text


def process_image(image):
    """Run the full pipeline: caption, translate, extract colors, and build the palette."""
    # Gradio may pass a NumPy array; normalize to a PIL RGB image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_rgb = image.convert("RGB")

    # Resize once, for display and for faster k-means clustering.
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)

    # Caption in English, then translate to Arabic.
    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)

    # Extract the dominant colors and render them as a palette strip.
    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)

    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
    return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil


with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        """
        <p style='text-align: center;'>
        Upload an image or select one of the example images below to generate a color palette
        and a description of the image in both English and Arabic.
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(
                examples=examples,
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")

    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )


if __name__ == "__main__":
    demo.launch()