File size: 5,626 Bytes
969899b
35d0b13
969899b
b9713ec
 
 
 
 
 
 
 
 
 
 
 
521a5df
 
 
 
 
 
 
b9713ec
 
 
 
 
 
 
 
 
 
 
 
521a5df
 
 
b9713ec
 
 
521a5df
 
b9713ec
 
 
 
 
 
 
 
521a5df
b9713ec
 
 
521a5df
b9713ec
 
521a5df
b9713ec
 
 
 
521a5df
b9713ec
 
 
521a5df
b9713ec
 
 
 
521a5df
b9713ec
 
 
 
521a5df
 
b9713ec
 
 
 
 
521a5df
b9713ec
 
 
521a5df
b9713ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521a5df
b9713ec
 
 
521a5df
b9713ec
 
 
 
 
521a5df
b9713ec
 
 
 
 
 
 
 
 
 
 
 
 
 
521a5df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Install the required libraries before running this script. This is a shell
# command, not Python, so it must stay commented out here:
#   pip install torch transformers gradio Pillow scikit-learn requests

import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO
# Load the Hugging Face models once at import time; every request reuses these
# module-level objects.
print("Loading Hugging Face models...")
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_MBART_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"

# BLIP: image -> English caption.
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)

# mBART-50 many-to-many: English caption -> Arabic translation.
tokenizer = MBart50TokenizerFast.from_pretrained(_MBART_CHECKPOINT)
translation_model = MBartForConditionalGeneration.from_pretrained(_MBART_CHECKPOINT)

# Download example images
def download_example_images(timeout=10):
    """Download the bundled example photos and save them as local JPEG files.

    Each image is written to ``example{N}.jpg`` in the current working
    directory, where N is its 1-based position in the URL list. Images that
    fail to download, decode, or save are skipped with a printed message, so
    the returned list may be shorter than the URL list (or empty).

    Args:
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A list of one-element lists ``[path]``, the row format used for the
        single-input ``gr.Examples`` gallery.
    """
    image_urls = [
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
    ]

    example_images = []
    for idx, (_description, url) in enumerate(image_urls, start=1):
        path = f'example{idx}.jpg'
        try:
            # Without a timeout, requests can block application startup indefinitely.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as img:
                # JPEG cannot store alpha or palette modes; normalise to RGB first.
                img.convert("RGB").save(path)
        except (requests.RequestException, OSError) as e:
            # OSError also covers undecodable payloads (PIL.UnidentifiedImageError)
            # and failures writing the file, so one bad image never aborts startup.
            print(f"Failed to download image from {url}: {e}")
            continue
        example_images.append([path])
    return example_images

# Download example images and prepare examples list
examples = download_example_images()

# Resize the Image to a Fixed Size and Convert It to a NumPy Array
def load_image(image, size=(300, 300)):
    """Resize a PIL image and return its pixels as a NumPy array.

    The image is stretched to exactly ``size`` with Lanczos resampling; the
    aspect ratio is not preserved. process_image runs colour clustering on
    this resized array, so ``size`` bounds how many pixels KMeans sees.

    Args:
        image: PIL image to resize (process_image passes an RGB image).
        size: Target ``(width, height)`` in pixels; defaults to 300x300.

    Returns:
        ``np.array`` of the resized image.
    """
    resized_image = image.resize(tuple(size), resample=Image.LANCZOS)
    return np.array(resized_image)

# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """Find the ``k`` dominant colours of an image with k-means clustering.

    Args:
        image: NumPy array of RGB values in 0-255; it is flattened to one row
            of three channels per pixel (process_image passes an (H, W, 3) array).
        k: Number of clusters, i.e. the palette size.

    Returns:
        Integer array of shape (k, 3) holding the cluster centres as RGB values.
    """
    pixels = image.reshape(-1, 3) / 255.0
    kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    kmeans.fit(pixels)
    # Round rather than truncate: centres are float means, and (v / 255) * 255
    # can come out a hair below the integer v, which astype(int) would floor
    # to v - 1. Clip guards against any float drift past the valid range.
    centers = np.rint(kmeans.cluster_centers_ * 255)
    return np.clip(centers, 0, 255).astype(int)

# Create an Image for the Color Palette
def create_palette_image(colors, swatch_size=100):
    """Render the palette as a horizontal strip of square colour swatches.

    Args:
        colors: Sequence of RGB triples; channels are rounded to the nearest
            integer and clipped to 0-255.
        swatch_size: Edge length in pixels of each square swatch.

    Returns:
        A PIL RGB image of size ``(swatch_size * len(colors), swatch_size)``.
    """
    num_colors = len(colors)
    palette_image = Image.new("RGB", (swatch_size * num_colors, swatch_size))
    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        channels = np.clip(np.rint(np.asarray(color, dtype=float)), 0, 255)
        # Hand PIL a tuple of plain Python ints rather than NumPy scalars.
        fill = tuple(int(c) for c in channels)
        left = i * swatch_size
        # ImageDraw.rectangle bounds are inclusive, so stop one pixel short to
        # keep neighbouring swatches from overlapping or spilling past the edge.
        draw.rectangle([left, 0, left + swatch_size - 1, swatch_size - 1], fill=fill)
    return palette_image

# Format the Color Palette as Hex Codes
def display_palette(colors):
    """Format RGB colours as lowercase ``#rrggbb`` hex strings.

    Channel values are rounded to the nearest integer and clipped to 0-255,
    so float colours (e.g. raw cluster centres) are accepted as well as ints.

    Args:
        colors: Iterable of RGB triples. Only the first three channels are used.

    Returns:
        A list with one hex string per colour, in input order.
    """
    hex_codes = []
    for color in colors:
        channels = np.clip(np.rint(np.asarray(color, dtype=float)), 0, 255)
        # The "x" format code rejects floats, so pass plain Python ints.
        hex_codes.append("#{:02x}{:02x}{:02x}".format(*(int(c) for c in channels)))
    return hex_codes

# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    """Describe an image in English using the module-level BLIP model.

    Args:
        image: Image to caption; passed straight to the BLIP processor
            (process_image supplies an RGB PIL image).

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    sequences = caption_model.generate(**model_inputs)
    first_sequence = sequences[0]
    return processor.decode(first_sequence, skip_special_tokens=True)

# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    """Translate English text to Arabic with the module-level mBART-50 model.

    Args:
        text: English source text.

    Returns:
        The Arabic translation, or an empty string when ``text`` is empty or
        whitespace-only (the model is not called in that case).
    """
    if not (text and text.strip()):
        return ""
    # mBART-50 reads the source language from the tokenizer. Note this mutates
    # the shared module-level tokenizer (always to the same value here).
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt")
    # The target language is chosen by forcing its language-code token first.
    # convert_tokens_to_ids works whether or not the deprecated
    # lang_code_to_id mapping is present in the installed transformers version.
    target_lang_id = tokenizer.convert_tokens_to_ids("ar_AR")
    generated_tokens = translation_model.generate(
        **encoded,
        forced_bos_token_id=target_lang_id,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

# Gradio Interface Function (Combining Elements)
def process_image(image):
    """Caption an image in English and Arabic and extract its colour palette.

    Args:
        image: PIL image or NumPy array from the Gradio image input; ``None``
            when the user submits without choosing an image.

    Returns:
        A 4-tuple ``(bilingual_caption, hex_codes, palette_image, resized_image)``:
        the caption text, a comma-separated string of hex codes, the palette
        swatch image, and the resized input image.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the user.
    """
    if image is None:
        # Without this guard, clicking Submit with an empty input crashes on
        # image.convert() with an opaque AttributeError.
        raise gr.Error("Please upload an image or select an example first.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image_rgb = image.convert("RGB")
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)

    # Caption from the full-resolution image; cluster colours on the smaller copy.
    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)

    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)

    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

    return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil

# Create Gradio Interface
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        "<p style='text-align: center;'>Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.</p>"
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(examples=examples, inputs=image_input, label="Example Images", examples_per_page=5)
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")
    
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )

# Launch Gradio Interface
demo.launch()