# Hugging Face Space by ayajoharji — app.py (commit 35d0b13 "Update app.py", verified; 5.63 kB)
# Install the necessary libraries first (run in a shell, not in Python):
#   pip install torch transformers gradio Pillow scikit-learn requests
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
MBartForConditionalGeneration,
MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO
# Load the Hugging Face models once at import time so every request reuses them.
print("Loading Hugging Face models...")
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_TRANSLATION_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"

# BLIP: image -> English caption (used by generate_caption).
processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)

# mBART-50 many-to-many: used for English -> Arabic translation (translate_to_arabic).
tokenizer = MBart50TokenizerFast.from_pretrained(_TRANSLATION_CHECKPOINT)
translation_model = MBartForConditionalGeneration.from_pretrained(_TRANSLATION_CHECKPOINT)
# Download example images
def download_example_images():
image_urls = [
("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
]
example_images = []
for idx, (description, url) in enumerate(image_urls, start=1):
try:
response = requests.get(url)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
img.save(f'example{idx}.jpg')
example_images.append([f'example{idx}.jpg'])
except requests.RequestException as e:
print(f"Failed to download image from {url}: {e}")
return example_images
# Download example images and prepare examples list
examples = download_example_images()
# Load and Process the Entire Image
def load_image(image):
    """Downscale *image* to exactly 300x300 pixels and return it as a NumPy array.

    Lanczos resampling is used, and the aspect ratio is not preserved.

    Args:
        image: PIL image to resize (an RGB image in this app).

    Returns:
        A new NumPy array holding the resized pixels; for an RGB input the
        shape is ``(300, 300, 3)``.
    """
    target_width, target_height = 300, 300
    shrunk = image.resize((target_width, target_height), resample=Image.LANCZOS)
    pixel_array = np.array(shrunk)
    return pixel_array
# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    """Return the dominant colours of an image array using k-means clustering.

    Args:
        image: Pixel array whose last axis holds colour channels, e.g. the
            ``(H, W, 3)`` array from ``load_image``. Channels beyond the first
            three (such as alpha) are ignored.
        k: Number of colours to find. Clamped to the number of pixels so very
            small images do not make KMeans raise.

    Returns:
        Integer array of shape ``(min(k, n_pixels), 3)`` with RGB values in 0..255,
        one row per cluster centre.
    """
    # Keep only RGB: reshaping an RGBA array straight to (-1, 3) would mix channels.
    pixels = np.asarray(image)[..., :3].reshape(-1, 3) / 255.0
    n_clusters = min(k, len(pixels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=10, max_iter=300)
    kmeans.fit(pixels)
    # Round rather than truncate: astype(int) alone would turn 127.9999 into 127.
    return np.rint(kmeans.cluster_centers_ * 255).astype(int)
# Create an Image for the Color Palette
def create_palette_image(colors, swatch_size=100):
    """Render *colors* as a horizontal strip of square swatches.

    Args:
        colors: Sequence of RGB triples (ints or floats). Values are rounded
            and clipped to 0..255.
        swatch_size: Edge length in pixels of each square swatch (positive int).

    Returns:
        A PIL ``"RGB"`` image of size ``(swatch_size * len(colors), swatch_size)``.
    """
    num_colors = len(colors)
    palette_image = Image.new("RGB", (swatch_size * num_colors, swatch_size))
    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Normalise numpy scalars and floats to plain Python ints for PIL.
        fill = tuple(int(c) for c in np.clip(np.rint(color), 0, 255))
        left = i * swatch_size
        # ImageDraw.rectangle includes both corners, so stop one pixel short
        # to keep each swatch exactly swatch_size wide and inside the image.
        draw.rectangle([left, 0, left + swatch_size - 1, swatch_size - 1], fill=fill)
    return palette_image
# Display Color Palette as Hex Codes
def display_palette(colors):
    """Format colours as lowercase ``#rrggbb`` hex strings.

    Args:
        colors: Iterable of RGB triples. Values may be ints or floats; they are
            rounded and clipped to 0..255. Only the first three channels are used.

    Returns:
        List of hex strings, one per colour, in input order.
    """
    hex_codes = []
    for color in colors:
        # int() is required: the ':02x' format rejects floats (and float numpy scalars).
        r, g, b = (int(c) for c in np.clip(np.rint(color), 0, 255)[:3])
        hex_codes.append(f"#{r:02x}{g:02x}{b:02x}")
    return hex_codes
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    """Describe *image* in English with the module-level BLIP model.

    Args:
        image: Image accepted by the BLIP processor (a PIL RGB image in this app).

    Returns:
        The caption decoded from the first generated sequence, with special
        tokens removed.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    generated_ids = caption_model.generate(**model_inputs)
    first_sequence = generated_ids[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    """Translate English *text* to Arabic with the module-level mBART-50 model.

    Args:
        text: English source sentence (e.g. a BLIP caption).

    Returns:
        The Arabic translation, or ``""`` when *text* is empty, ``None`` or
        whitespace-only (the model is not run for an empty prompt).
    """
    if not text or not text.strip():
        return ""
    # The source language is set on the shared tokenizer before every call.
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt")
    generated_tokens = translation_model.generate(
        **encoded,
        # Force the output to begin with the Arabic language-code token; its id is
        # looked up from the token string rather than the lang_code_to_id mapping.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("ar_AR"),
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Gradio Interface Function (Combining Elements)
def process_image(image):
    """Caption an image in English and Arabic and extract its colour palette.

    Args:
        image: PIL image (or NumPy array) from the Gradio input component;
            ``None`` when Submit is pressed before an image is chosen.

    Returns:
        Tuple ``(bilingual_caption, hex_codes, palette_image, resized_image)``
        where ``hex_codes`` is a comma-separated string, matching the four
        output components in order.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image or select an example first.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_rgb = image.convert("RGB")

    # Colours are clustered on the downscaled copy from load_image (fewer
    # pixels for k-means); the caption is generated from the full-size image.
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)

    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)

    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)

    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
    return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
# Create Gradio Interface
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        "<p style='text-align: center;'>Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.</p>"
    )
    with gr.Row():
        # Left column: input image, submit button and example gallery.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            # download_example_images skips failed downloads, so the list may be
            # empty (e.g. no network at start-up); omit the gallery in that case.
            if examples:
                gr.Examples(examples=examples, inputs=image_input, label="Example Images", examples_per_page=5)
        # Right column: the four outputs returned by process_image, in order.
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )
# Launch Gradio Interface
demo.launch()