# app.py
import numpy as np
import gradio as gr
from sklearn.cluster import KMeans
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
)
from PIL import Image, ImageDraw
import requests
from io import BytesIO

# Load models globally at startup
print("Loading models...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
mbart_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
print("Models loaded successfully.")


# Download example images
def download_example_images():
    image_urls = [
        # Format: ("Image Description", "Image URL")
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b?w=512"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512"),
    ]
    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            response = requests.get(url)
            if response.status_code == 200:
                img = Image.open(BytesIO(response.content))
                img.save(f'example{idx}.jpg')
                example_images.append([f'example{idx}.jpg'])
            else:
                print(f"Failed to download image from {url}")
        except Exception as e:
            print(f"Exception occurred while downloading image: {e}")
    return example_images


# Download example images and prepare examples list
examples = download_example_images()


# Load and Process the Entire Image
def load_image(image):
    # Convert PIL image to numpy array (RGB)
    image_np = np.array(image.convert('RGB'))
    # Resize the image for better processing
    resized_image = image.resize((224, 224), resample=Image.LANCZOS)
    resized_image_np = np.array(resized_image)
    return resized_image_np


# Extract Dominant Colors from the Image
def extract_colors(image, k=8):
    # Flatten the image
    pixels = image.reshape(-1, 3)
    # Normalize pixel values to [0, 1]
    pixels = pixels / 255.0
    # Ensure data type is float64
    pixels = pixels.astype(np.float64)
    # Apply K-means clustering to find dominant colors
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
    kmeans.fit(pixels)
    # Convert normalized colors back to 0-255 scale
    colors = (kmeans.cluster_centers_ * 255).astype(int)
    return colors


# Create an Image for the Color Palette
def create_palette_image(colors):
    num_colors = len(colors)
    palette_height = 50
    palette_width = 50 * num_colors
    palette_image = Image.new("RGB", (palette_width, palette_height))
    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Ensure color values are within the valid range and are integers
        color = tuple(np.clip(color, 0, 255).astype(int))
        draw.rectangle([i * 50, 0, (i + 1) * 50, palette_height], fill=color)
    return palette_image


# Display Color Palette as Hex Codes
def display_palette(colors):
    hex_colors = []
    for color in colors:
        # Ensure color values are within the valid range and are integers
        color = np.clip(color, 0, 255).astype(int)
        hex_color = "#{:02x}{:02x}{:02x}".format(color[0], color[1], color[2])
        hex_colors.append(hex_color)
    return hex_colors


# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    inputs = blip_processor(images=image, return_tensors="pt")
    output = blip_model.generate(**inputs)
    caption = blip_processor.decode(output[0], skip_special_tokens=True)
    return caption


# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    mbart_tokenizer.src_lang = "en_XX"
    encoded = mbart_tokenizer(text, return_tensors="pt")
    generated_tokens = mbart_model.generate(
        **encoded,
        forced_bos_token_id=mbart_tokenizer.lang_code_to_id["ar_AR"],
    )
    translated_text = mbart_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated_text


# Gradio Interface Function (Combining Elements)
def process_image(image):
    try:
        # Ensure input is a PIL Image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to RGB format for PIL processing
        image_rgb = image.convert("RGB")

        # Load and resize the entire image
        resized_image_np = load_image(image_rgb)

        # Convert resized image to PIL Image for Gradio output
        resized_image_pil = Image.fromarray(resized_image_np)

        # Generate caption using BLIP model
        caption = generate_caption(image_rgb)

        # Translate caption to Arabic
        caption_arabic = translate_to_arabic(caption)

        # Extract dominant colors from the entire image
        colors = extract_colors(resized_image_np, k=8)
        color_palette = display_palette(colors)

        # Create palette image
        palette_image = create_palette_image(colors)

        # Combine English and Arabic captions
        bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

        return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
    except Exception as e:
        print(f"Error during processing: {e}")
        return "An error occurred during processing.", "", None, None


# Create Gradio Interface using Blocks and add a submit button
with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
    gr.Markdown(
        """
        Upload an image or select one of the example images below to generate a color palette
        and a description of the image in both English and Arabic.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload your image or select an example below")
            submit_button = gr.Button("Submit")
            gr.Examples(
                examples=examples,
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )
        with gr.Column(scale=1):
            caption_output = gr.Textbox(label="Bilingual Caption", lines=5, max_lines=10)
            palette_hex_output = gr.Textbox(label="Color Palette Hex Codes", lines=2)
            palette_image_output = gr.Image(type="pil", label="Color Palette")
            resized_image_output = gr.Image(type="pil", label="Resized Image")

    submit_button.click(
        fn=process_image,
        inputs=image_input,
        outputs=[caption_output, palette_hex_output, palette_image_output, resized_image_output],
    )

# Launch Gradio Interface
demo.launch()