# Install necessary libraries (run this in a shell before starting the script):
#   pip install torch transformers gradio Pillow scikit-learn requests
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO
# Load Hugging Face models globally
print("Loading Hugging Face models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
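# Optional sketch (assumption, not part of the original app): if a CUDA GPU is
# available, moving both models to it speeds up inference. The inputs built
# below with return_tensors="pt" would then also need .to(device).
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# caption_model.to(device)
# translation_model.to(device)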
# Download example images
def download_example_images():
    image_urls = [
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
    ]
    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            response = requests.get(url)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            img.save(f'example{idx}.jpg')
            example_images.append([f'example{idx}.jpg'])
        except requests.RequestException as e:
            print(f"Failed to download image from {url}: {e}")
    return example_images
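# Note: the Unsplash downloads require network access at startup; any image
# that fails to download is simply skipped and the app still launches.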
# Download example images and prepare examples list
examples = download_example_images()
# Load and Process the Entire Image
def load_image(image):
    resized_image = image.resize((300, 300), resample=Image.LANCZOS)
    return np.array(resized_image)
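# Downsampling to 300x300 keeps the k-means step below fast; LANCZOS is a
# high-quality resampling filter, so dominant colors are largely preserved.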
# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    pixels = image.reshape(-1, 3) / 255.0
    kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
    kmeans.fit(pixels)
    return (kmeans.cluster_centers_ * 255).astype(int)
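# extract_colors expects an RGB array of shape (H, W, 3); pixels are scaled to
# [0, 1] for clustering and the k cluster centers are rescaled to 0-255 RGB.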
# Create an Image for the Color Palette
def create_palette_image(colors):
    num_colors = len(colors)
    palette_image = Image.new("RGB", (100 * num_colors, 100))
    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        color = tuple(np.clip(color, 0, 255).astype(int))
        draw.rectangle([i * 100, 0, (i + 1) * 100, 100], fill=color)
    return palette_image
# Display Color Palette as Hex Codes
def display_palette(colors):
    return ["#{:02x}{:02x}{:02x}".format(*np.clip(color, 0, 255)) for color in colors]
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    inputs = processor(images=image, return_tensors="pt")
    output = caption_model.generate(**inputs)
    return processor.decode(output[0], skip_special_tokens=True)
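# generate() runs with the model's default generation settings here; passing
# e.g. max_new_tokens=50 would allow longer captions if needed (assumption:
# the defaults are acceptable for short scene descriptions).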
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt")
    generated_tokens = translation_model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"]
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
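# mBART-50 is multilingual: the source language is set via tokenizer.src_lang
# ("en_XX") and the target is selected by forcing the first generated token to
# the "ar_AR" language code. Illustrative (commented-out) usage:
#   translate_to_arabic("a sunset over the mountains")  # -> Arabic caption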
# Gradio Interface Function (Combining Elements)
def process_image(image):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image_rgb = image.convert("RGB")
    resized_image_np = load_image(image_rgb)
    resized_image_pil = Image.fromarray(resized_image_np)
    caption = generate_caption(image_rgb)
    caption_arabic = translate_to_arabic(caption)
    colors = extract_colors(resized_image_np, k=8)
    color_palette = display_palette(colors)
    palette_image = create_palette_image(colors)
    bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
    return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
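# Standalone usage sketch (assumes one of the example images downloaded above
# exists on disk; the file name and printed fields are illustrative):
#   caption, hex_codes, palette_img, resized = process_image(Image.open("example1.jpg"))
#   print(caption)
#   print(hex_codes)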
# Create Gradio Interface
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Palette Generator from Image with Image Captioning</h1>")
    gr.Markdown(
        "<p style='text-align: center;'>Upload an image or select one of the example images below to generate a color palette and a description of the image in both English and Arabic.</p>"
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(examples=examples, inputs=image_input, label="Example Images", examples_per_page=5)
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")
    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )
# Launch Gradio Interface
demo.launch()