# app.py
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO
# Download example images
def download_example_images():
    # (description, URL) pairs for the example gallery
    image_urls = [
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
    ]
    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        # A timeout keeps a slow or dead URL from hanging app startup
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            img = Image.open(BytesIO(response.content))
            img.save(f'example{idx}.jpg')
            example_images.append([f'example{idx}.jpg'])
        else:
            print(f"Failed to download image from {url}")
    return example_images
# Download example images and prepare examples list
examples = download_example_images()
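# The example files are fetched once at startup; if a URL is unreachable, that
# example is simply skipped (with a console message) and the app still launches.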
# Load and preprocess the input image
def load_image(image):
    # Convert to RGB first, then downscale so K-means runs on a manageable number of pixels
    image = image.convert('RGB')
    resized_image = image.resize((300, 300), resample=Image.LANCZOS)
    return np.array(resized_image)
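# load_image always returns a (300, 300, 3) uint8 array, regardless of the
# input image's original size or mode.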
# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    # Flatten the (H, W, 3) array into a list of pixels
    pixels = image.reshape(-1, 3)
    # Normalize pixel values to [0, 1] and use float64 for stable clustering
    pixels = (pixels / 255.0).astype(np.float64)
    # Apply K-means clustering to find the k dominant colors
    kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    kmeans.fit(pixels)
    # Convert the normalized cluster centers back to the 0-255 scale
    colors = (kmeans.cluster_centers_ * 255).astype(int)
    return colors
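# Illustrative shapes (values are hypothetical): a 300x300 input makes `pixels`
# a (90000, 3) array; with k=8, `kmeans.cluster_centers_` is (8, 3), so e.g. a
# centroid of [0.91, 0.76, 0.52] becomes the RGB color [232, 193, 132].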
# Create an Image for the Color Palette
def create_palette_image(colors):
    num_colors = len(colors)
    palette_height = 100
    palette_width = 100 * num_colors
    palette_image = Image.new("RGB", (palette_width, palette_height))
    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Clamp values to the valid range and pass plain Python ints to PIL
        color = tuple(int(c) for c in np.clip(color, 0, 255))
        draw.rectangle([i * 100, 0, (i + 1) * 100, palette_height], fill=color)
    return palette_image
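# The palette is a horizontal strip of 100x100 swatches, one per cluster,
# drawn left to right in the order K-means returns the centers.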
# Display Color Palette as Hex Codes
def display_palette(colors):
    hex_colors = []
    for color in colors:
        # Clamp values to the valid range before formatting
        color = np.clip(color, 0, 255).astype(int)
        hex_color = "#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2])
        hex_colors.append(hex_color)
    return hex_colors
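# e.g. [232, 193, 132] -> "#e8c184"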
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    # Lazily load the BLIP processor and model once, caching them on the function object
    if 'processor' not in generate_caption.__dict__:
        generate_caption.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        generate_caption.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    processor = generate_caption.processor
    model = generate_caption.model
    inputs = processor(images=image, return_tensors="pt")
    output = model.generate(**inputs)
    caption = processor.decode(output[0], skip_special_tokens=True)
    return caption
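# Storing the weights as function attributes means only the first call pays
# the download/load cost; later calls reuse the same model in memory.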
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    # Lazily load the mBART-50 tokenizer and model once, caching them on the function object
    if 'tokenizer' not in translate_to_arabic.__dict__:
        translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    tokenizer = translate_to_arabic.tokenizer
    model = translate_to_arabic.model
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
    )
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text
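# mBART-50 expects explicit language codes: "en_XX" marks the source text as
# English, and forcing "ar_AR" as the first generated token steers decoding
# toward Arabic.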
# Gradio Interface Function (Combining Elements)
def process_image(image):
    # Ensure the input is a PIL Image in RGB mode
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_rgb = image.convert("RGB")
    # Resize the image, keeping both array and PIL versions
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)
    # Generate an English caption with BLIP, then translate it to Arabic with mBART
    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)
    # Extract the dominant colors and render them as hex codes and a palette image
    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)
    # Combine the English and Arabic captions
    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
    return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
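# Sketch of calling the pipeline directly, outside Gradio (the file name
# "photo.jpg" is illustrative, not part of this app):
#   caption, hex_codes, palette_img, resized_img = process_image(Image.open("photo.jpg"))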
# Create Gradio Interface using Blocks and add a submit button
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        """
        <p style='text-align: center;'>
        Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(
                examples=examples,
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )
# Launch Gradio Interface
demo.launch()
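# demo.launch() serves the app locally; on Hugging Face Spaces, running app.py
# with this call is what starts the hosted interface.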