Spaces:
Running
Running
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import boto3 | |
| from PIL import Image | |
| from botocore.config import Config | |
| from botocore.exceptions import ClientError | |
| import gradio as gr | |
| import os | |
# Logging: configure the root handler at INFO level, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Custom exception for image errors
class ImageError(Exception):
    """Raised when Amazon Nova Canvas image generation fails.

    The description is stored on ``self.message`` and is also passed to
    ``Exception.__init__``. As a result, ``str(err)``, ``err.args`` and
    tracebacks show the message instead of an empty string.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message
# Bedrock model identifier for Amazon Nova Canvas.
model_id = "amazon.nova-canvas-v1:0"
# AWS credentials taken from the environment; each is None when the variable is unset.
aws_id = os.environ.get("AWS_ID")
aws_secret = os.environ.get("AWS_SECRET")
def process_and_encode_image(image, min_size=320, max_size=4096, max_pixels=4194304):
    """Normalize an image and return it as a base64-encoded PNG string.

    Steps:
      1. Open ``image`` with PIL if it is not already a ``PIL.Image.Image``.
      2. Flatten it to RGB. Any transparency (an alpha band, or a
         ``transparency`` entry in ``image.info``) is composited onto white.
      3. If ``width * height`` exceeds ``max_pixels``, shrink it
         proportionally.
      4. Clamp each side independently into ``[min_size, max_size]``. This
         can change the aspect ratio of very narrow or very wide images.

    Args:
        image: A PIL image, or a path / binary file object accepted by
            ``PIL.Image.open``.
        min_size: Smallest allowed width or height, in pixels.
        max_size: Largest allowed width or height, in pixels.
        max_pixels: Largest allowed total pixel count before downscaling.

    Returns:
        str: The PNG bytes, base64-encoded and decoded as UTF-8 text.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("Input image is required.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # Flatten to RGB. Every alpha-bearing mode (RGBA, LA, PA) and palette or
    # greyscale images carrying a "transparency" entry are composited onto
    # white. This keeps the colours hidden under transparent areas from
    # leaking through when the alpha channel is dropped.
    if "A" in image.getbands() or "transparency" in image.info:
        rgba = image.convert("RGBA")
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[3])  # alpha channel as mask
        image = background
    elif image.mode != "RGB":
        image = image.convert("RGB")

    # Shrink proportionally if the image has too many pixels. Scaling both
    # sides by sqrt(max_pixels / current_pixels) preserves the aspect ratio;
    # int() rounds each side down.
    current_pixels = image.width * image.height
    if current_pixels > max_pixels:
        scale = (max_pixels / current_pixels) ** 0.5
        new_size = (max(1, int(image.width * scale)), max(1, int(image.height * scale)))
        image = image.resize(new_size, Image.LANCZOS)

    # Clamp each side into [min_size, max_size]; only resize if something changed.
    clamped_size = (
        min(max(image.width, min_size), max_size),
        min(max(image.height, min_size), max_size),
    )
    if clamped_size != image.size:
        image = image.resize(clamped_size, Image.LANCZOS)

    # Serialize as PNG and base64-encode.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG', optimize=True)
    return base64.b64encode(buffer.getvalue()).decode('utf8')
# Function to generate an image using Amazon Nova Canvas model
def generate_image(body):
    """Invoke Amazon Nova Canvas on Bedrock and return the first image's bytes.

    Args:
        body: JSON-encoded request body, passed unchanged to
            ``bedrock.invoke_model``.

    Returns:
        bytes: The base64-decoded bytes of the first image in the response.

    Raises:
        ImageError: If Bedrock returns a client error, the response reports a
            non-null ``error``, or the response contains no images.
    """
    logger.info("Generating image with Amazon Nova Canvas model %s", model_id)
    # Configure the client with a longer timeout
    bedrock = boto3.client(
        service_name='bedrock-runtime',
        aws_access_key_id=aws_id,
        aws_secret_access_key=aws_secret,
        region_name='us-east-1',
        config=Config(read_timeout=300)  # 5-minute read timeout
    )
    # Request bodies embed base64-encoded input images and can be very large.
    # Log only their size, at DEBUG, instead of printing them to stdout.
    logger.debug("Nova Canvas request body: %d characters", len(body))
    try:
        response = bedrock.invoke_model(
            body=body,
            modelId=model_id,
            accept="application/json",
            contentType="application/json"
        )
    except ClientError as err:
        message = err.response.get("Error", {}).get("Message", str(err))
        logger.error("A client error occurred: %s", message)
        raise ImageError(f"Client error during image generation: {message}") from err

    response_body = json.loads(response.get("body").read())
    # Check for an error before processing the image. An "error" key that is
    # present but null does not count as a failure.
    error = response_body.get("error")
    if error is not None:
        raise ImageError(f"Image generation error. Error is {error}")
    images = response_body.get("images") or []
    if not images:
        raise ImageError("Image generation error. The response contained no images.")
    image_bytes = base64.b64decode(images[0])
    logger.info("Successfully generated image with Amazon Nova Canvas model %s", model_id)
    return image_bytes
# Function to display image from bytes
def display_image(image_bytes):
    """Decode encoded image bytes into a ``PIL.Image.Image`` for a Gradio output."""
    return Image.open(io.BytesIO(image_bytes))
# Gradio functions for each task
def text_to_image(prompt, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Generate an image from a text prompt (Nova Canvas ``TEXT_IMAGE`` task).

    ``negative_text`` is included in the request only when it is non-empty.
    Returns the image produced by ``generate_image`` and decoded by
    ``display_image``.
    """
    params = {"text": prompt}
    if negative_text:
        params["negativeText"] = negative_text

    generation_config = {
        "numberOfImages": 1,
        "height": height,
        "width": width,
        "quality": quality,
        "cfgScale": cfg_scale,
        "seed": seed,
    }
    request = {
        "taskType": "TEXT_IMAGE",
        "textToImageParams": params,
        "imageGenerationConfig": generation_config,
    }
    return display_image(generate_image(json.dumps(request)))
def inpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Regenerate masked regions of ``image`` (Nova Canvas ``INPAINTING`` task).

    The region to edit is given by exactly one of ``mask_prompt`` (a text
    description) or ``mask_image``. ``text`` and ``negative_text`` are sent
    only when non-empty.

    Returns:
        The generated image, as decoded by ``display_image``.

    Raises:
        ValueError: If ``image`` is missing, or if neither or both of
            ``mask_prompt`` / ``mask_image`` are supplied.
    """
    if image is None:
        raise ValueError("Input image is required.")
    # Validate the mask choice before doing any image encoding.
    if mask_prompt and mask_image is not None:
        raise ValueError("You must specify either maskPrompt or maskImage, but not both.")
    if not mask_prompt and mask_image is None:
        raise ValueError("You must specify either maskPrompt or maskImage.")

    # Prepare the inPaintingParams dictionary with the appropriate mask parameter
    in_painting_params = {
        "image": process_and_encode_image(image)
    }
    if mask_prompt:
        in_painting_params["maskPrompt"] = mask_prompt
    else:
        # Encode the mask itself, not the input image. Both pass through the
        # same normalisation, so a mask with the same original size as
        # ``image`` ends up with the same final size.
        in_painting_params["maskImage"] = process_and_encode_image(mask_image)
    if text:
        in_painting_params["text"] = text
    if negative_text:
        in_painting_params["negativeText"] = negative_text

    body = json.dumps({
        "taskType": "INPAINTING",
        "inPaintingParams": in_painting_params,
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "height": height,
            "width": width,
            "quality": quality,
            "cfgScale": cfg_scale,
            "seed": seed
        }
    })
    image_bytes = generate_image(body)
    return display_image(image_bytes)
def outpainting(image, mask_prompt=None, mask_image=None, text=None, negative_text=None, outpainting_mode="DEFAULT", height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Extend ``image`` beyond its borders (Nova Canvas ``OUTPAINTING`` task).

    At least one of ``mask_prompt`` / ``mask_image`` is required. When both
    are given, the mask image is used and the mask prompt is ignored.
    ``text`` and ``negative_text`` are sent only when non-empty.

    Returns:
        The generated image, as decoded by ``display_image``.

    Raises:
        ValueError: If ``image`` is missing or no mask option is supplied.
    """
    if image is None:
        raise ValueError("Input image is required.")
    if not mask_prompt and mask_image is None:
        raise ValueError("You must specify either maskPrompt or maskImage.")

    # Prepare the outPaintingParams dictionary
    out_painting_params = {
        "image": process_and_encode_image(image),
        "outPaintingMode": outpainting_mode
    }
    if mask_image is not None:
        # Normalise the mask the same way as the input image.
        # process_and_encode_image may resize the input, and a raw mask would
        # then no longer match its dimensions.
        out_painting_params["maskImage"] = process_and_encode_image(mask_image)
        if mask_prompt:
            logger.info("Both a mask image and a mask prompt were given; using the mask image.")
    else:
        out_painting_params["maskPrompt"] = mask_prompt
    if text:
        out_painting_params["text"] = text
    if negative_text:
        out_painting_params["negativeText"] = negative_text

    body = json.dumps({
        "taskType": "OUTPAINTING",
        "outPaintingParams": out_painting_params,
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "height": height,
            "width": width,
            "quality": quality,
            "cfgScale": cfg_scale,
            "seed": seed
        }
    })
    image_bytes = generate_image(body)
    return display_image(image_bytes)
def image_variation(images, text=None, negative_text=None, similarity_strength=0.5, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Create a variation of reference images (Nova Canvas ``IMAGE_VARIATION`` task).

    Args:
        images: File paths of 1-5 reference images. A single path (``str`` or
            ``os.PathLike``) is also accepted.
        text: Optional prompt; sent only when non-empty.
        negative_text: Optional text to exclude; sent only when non-empty.
        similarity_strength: Value sent as ``similarityStrength``.
        height, width, quality, cfg_scale, seed: ``imageGenerationConfig`` values.

    Returns:
        The generated image, as decoded by ``display_image``.

    Raises:
        ValueError: If no images, or more than five, are supplied.
    """
    max_images = 5  # matches the "up to 5 other images" limit stated in the UI
    if isinstance(images, (str, os.PathLike)):
        images = [images]
    # Gradio passes None when nothing has been uploaded.
    if not images:
        raise ValueError("At least one input image is required.")
    if len(images) > max_images:
        raise ValueError(f"Image variation supports at most {max_images} input images.")

    encoded_images = []
    for image_path in images:
        # Encode while the file is still open; PIL reads image data lazily.
        with open(image_path, "rb") as image_file:
            encoded_images.append(process_and_encode_image(image_file))

    # Prepare the imageVariationParams dictionary
    image_variation_params = {
        "images": encoded_images,
        "similarityStrength": similarity_strength
    }
    if text:
        image_variation_params["text"] = text
    if negative_text:
        image_variation_params["negativeText"] = negative_text

    body = json.dumps({
        "taskType": "IMAGE_VARIATION",
        "imageVariationParams": image_variation_params,
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "height": height,
            "width": width,
            "quality": quality,
            "cfgScale": cfg_scale,
            "seed": seed
        }
    })
    image_bytes = generate_image(body)
    return display_image(image_bytes)
def image_conditioning(condition_image, text, negative_text=None, control_mode="CANNY_EDGE", control_strength=0.7, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Generate an image guided by a condition image (Nova Canvas ``TEXT_IMAGE`` task).

    The condition image is normalised and encoded, then sent together with
    ``controlMode`` / ``controlStrength``. ``negative_text`` is included only
    when non-empty.

    Raises:
        ValueError: If ``condition_image`` is None.
    """
    if condition_image is None:
        raise ValueError("Input image is required.")

    params = {
        "text": text,
        "conditionImage": process_and_encode_image(condition_image),
        "controlMode": control_mode,
        "controlStrength": control_strength,
    }
    if negative_text:
        params["negativeText"] = negative_text

    payload = json.dumps({
        "taskType": "TEXT_IMAGE",
        "textToImageParams": params,
        "imageGenerationConfig": dict(
            numberOfImages=1,
            height=height,
            width=width,
            quality=quality,
            cfgScale=cfg_scale,
            seed=seed,
        ),
    })
    return display_image(generate_image(payload))
def color_guided_content(text=None, reference_image=None, negative_text=None, colors=None, height=1024, width=1024, quality="standard", cfg_scale=8.0, seed=0):
    """Generate an image guided by a colour palette (Nova Canvas ``COLOR_GUIDED_GENERATION`` task).

    Args:
        text: Text prompt, sent as ``text``.
        reference_image: Optional image; when given, it is encoded and sent as
            ``referenceImage``.
        negative_text: Optional text to exclude; sent only when non-empty.
        colors: Comma-separated hex colours, e.g. ``"#00FF00, #FCF2AB"``.
            Whitespace around entries is ignored and empty entries are
            dropped. If nothing remains, a ten-colour starter palette is used.
        height, width, quality, cfg_scale, seed: ``imageGenerationConfig`` values.

    Returns:
        The generated image, as decoded by ``display_image``.
    """
    # Encode the reference image if provided
    if reference_image is not None:
        reference_image_encoded = process_and_encode_image(reference_image)
    else:
        reference_image_encoded = None

    # Strip whitespace so "#00FF00, #FCF2AB" does not send " #FCF2AB", and
    # drop empty entries produced by stray or trailing commas.
    color_list = [c.strip() for c in (colors or "").split(",") if c.strip()]
    if not color_list:
        color_list = [
            "#FF5733", "#33FF57", "#3357FF", "#FF33A1", "#33FFF5",
            "#FF8C33", "#8C33FF", "#33FF8C", "#FF3333", "#33A1FF",
        ]

    # Prepare the colorGuidedGenerationParams dictionary
    color_guided_generation_params = {
        "text": text,
        "colors": color_list
    }
    if negative_text:
        color_guided_generation_params["negativeText"] = negative_text
    if reference_image_encoded:
        color_guided_generation_params["referenceImage"] = reference_image_encoded

    body = json.dumps({
        "taskType": "COLOR_GUIDED_GENERATION",
        "colorGuidedGenerationParams": color_guided_generation_params,
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "height": height,
            "width": width,
            "quality": quality,
            "cfgScale": cfg_scale,
            "seed": seed
        }
    })
    image_bytes = generate_image(body)
    return display_image(image_bytes)
def background_removal(image):
    """Remove the background of ``image`` (Nova Canvas ``BACKGROUND_REMOVAL`` task).

    ``process_and_encode_image`` raises ValueError when ``image`` is None.
    """
    request = {
        "taskType": "BACKGROUND_REMOVAL",
        "backgroundRemovalParams": {"image": process_and_encode_image(image)},
    }
    result_bytes = generate_image(json.dumps(request))
    return display_image(result_bytes)
# Gradio Interface

# Ranges for the shared generation controls. The Amazon Nova Canvas
# imageGenerationConfig documentation limits width/height to 320-4096 px
# (divisible by 16) and cfgScale to 1.1-10.0. Every slider position below
# stays inside those limits (320 + k * 64 is always divisible by 16).
_SIDE_MIN, _SIDE_MAX, _SIDE_STEP = 320, 2048, 64
_CFG_MIN, _CFG_MAX, _CFG_STEP = 1.1, 10.0, 0.1


def _centered_markdown(text):
    """Add a centred description paragraph to the current Gradio layout."""
    return gr.Markdown(f'<div style="text-align: center;">\n{text}\n</div>')


def _generation_controls():
    """Add the width, height, quality, CFG-scale and seed controls.

    Call this inside an open Gradio layout context (the "Advanced Options"
    accordions below) so the components render there.

    Returns:
        tuple: The ``(width, height, quality, cfg_scale, seed)`` components.
    """
    width = gr.Slider(minimum=_SIDE_MIN, maximum=_SIDE_MAX, step=_SIDE_STEP, value=1024, label="Width")
    height = gr.Slider(minimum=_SIDE_MIN, maximum=_SIDE_MAX, step=_SIDE_STEP, value=1024, label="Height")
    quality = gr.Radio(choices=["standard", "premium"], value="standard", label="Quality")
    cfg_scale = gr.Slider(minimum=_CFG_MIN, maximum=_CFG_MAX, step=_CFG_STEP, value=8.0, label="CFG Scale")
    seed = gr.Slider(minimum=1, maximum=2000, step=1, value=8, label="Seed")
    return width, height, quality, cfg_scale, seed


with gr.Blocks() as demo:
    gr.HTML("""
<style>
#component-0 {
    max-width: 800px;
    margin: 0 auto;
}
</style>
""")
    gr.Markdown("# Amazon Nova Canvas Image Generation")

    with gr.Tab("Text to Image"):
        with gr.Column():
            gr.Markdown("Generate an image from a text prompt using the Amazon Nova Canvas model.")
            prompt = gr.Textbox(label="Prompt", placeholder="Enter a text prompt (1-1024 characters)", max_lines=1)
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Enter text to exclude (1-1024 characters)", max_lines=1)
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                text_to_image,
                inputs=[prompt, negative_text, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Inpainting"):
        with gr.Column():
            _centered_markdown(
                "Modify specific areas of your image using inpainting. Upload your image and choose one of two ways "
                "to specify the areas you want to edit: You can use a photo editing tool to draw masks (using pure "
                "black for areas to edit and pure white for areas to preserve) or use the Mask Prompt field to "
                "direct the model in how to infer the mask."
            )
            image = gr.Image(type='pil', label="Input Image")
            mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
            # The generation prompt is not part of the mask, so keep it visible
            # rather than inside the collapsed "Mask Image" accordion.
            text = gr.Textbox(label="Text", placeholder="Describe what to generate (1-1024 characters)", max_lines=1)
            with gr.Accordion("Mask Image", open=False):
                mask_image = gr.Image(type='pil', label="Mask Image")
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Describe what not to include (1-1024 characters)", max_lines=1)
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                inpainting,
                inputs=[image, mask_prompt, mask_image, text, negative_text, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Outpainting"):
        with gr.Column():
            gr.Markdown("Extend an image beyond its original borders using a mask and text prompt.")
            _centered_markdown(
                "Modify areas outside of your image using outpainting. Upload your image and choose one of two ways "
                "to specify the areas you want to edit: You can use a photo editing tool to draw masks extended "
                "outside of an image's original borders (using pure black for areas to edit and pure white for "
                "areas to preserve) or use the Mask Prompt field to direct the model in how to infer the mask."
            )
            image = gr.Image(type='pil', label="Input Image")
            mask_prompt = gr.Textbox(label="Mask Prompt", placeholder="Describe regions to edit", max_lines=1)
            # The generation prompt is not part of the mask, so keep it visible
            # rather than inside the collapsed "Mask Image" accordion.
            text = gr.Textbox(label="Text", placeholder="Describe what to generate (1-1024 characters)", max_lines=1)
            with gr.Accordion("Mask Image", open=False):
                mask_image = gr.Image(type='pil', label="Mask Image")
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Describe what not to include (1-1024 characters)", max_lines=1)
                outpainting_mode = gr.Radio(choices=["DEFAULT", "PRECISE"], value="DEFAULT", label="Outpainting Mode")
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                outpainting,
                inputs=[image, mask_prompt, mask_image, text, negative_text, outpainting_mode, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Image Variation"):
        with gr.Column():
            _centered_markdown("Create a variation image based on up to 5 other images and a text description (optional).")
            images = gr.File(type='filepath', label="Input Images", file_count="multiple", file_types=["image"])
            text = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=1)
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Enter text to exclude (1-1024 characters)", max_lines=1)
                similarity_strength = gr.Slider(minimum=0.2, maximum=1.0, step=0.1, value=0.7, label="Similarity Strength")
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                image_variation,
                inputs=[images, text, negative_text, similarity_strength, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Image Conditioning"):
        with gr.Column():
            _centered_markdown("Generate an image conditioned on an input image and a text prompt (required).")
            condition_image = gr.Image(type='pil', label="Condition Image")
            text = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=1)
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Describe what not to include (1-1024 characters)", max_lines=1)
                control_mode = gr.Radio(choices=["CANNY_EDGE", "SEGMENTATION"], value="CANNY_EDGE", label="Control Mode")
                control_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Control Strength")
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                image_conditioning,
                inputs=[condition_image, text, negative_text, control_mode, control_strength, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Color Guided Content"):
        with gr.Column():
            _centered_markdown(
                "Generate an image using a color palette from a reference image or a text prompt. "
                "Starter colors are provided."
            )
            reference_image = gr.Image(type='pil', label="Reference Image")
            colors = gr.Textbox(label="Colors", placeholder="Enter up to 10 colors as hex values, e.g., #00FF00,#FCF2AB", max_lines=1)
            text = gr.Textbox(label="Text", placeholder="Enter a text prompt (1-1024 characters)", max_lines=1)
            output = gr.Image()
            with gr.Accordion("Advanced Options", open=False):
                negative_text = gr.Textbox(label="Negative Prompt", placeholder="Enter text to exclude (1-1024 characters)", max_lines=1)
                width, height, quality, cfg_scale, seed = _generation_controls()
            gr.Button("Generate").click(
                color_guided_content,
                inputs=[text, reference_image, negative_text, colors, height, width, quality, cfg_scale, seed],
                outputs=output,
            )

    with gr.Tab("Background Removal"):
        with gr.Column():
            _centered_markdown("Remove the background from an image.")
            image = gr.Image(type='pil', label="Input Image")
            output = gr.Image()
            gr.Button("Generate").click(background_removal, inputs=image, outputs=output)

if __name__ == "__main__":
    demo.launch()