|
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from PIL import Image
|
|
|
|
|
# Choose the device before loading any weights so they are loaded in a dtype
# that device actually supports: float16 kernels are only reliably available
# on CUDA, so fall back to float32 when running on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ColorizeNet ControlNet weights (safetensors checkpoint).
controlnet = ControlNetModel.from_pretrained(
    "rsortino/ColorizeNet",
    torch_dtype=dtype,
    use_safetensors=True,
)

# Stable Diffusion 2.1 img2img pipeline conditioned on the ControlNet above.
# The NSFW safety checker is disabled the supported way (safety_checker=None).
# Replacing it with a stub that returns a bare ``False`` breaks diffusers,
# which iterates over the per-image NSFW flags the checker returns.
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    controlnet=controlnet,
    torch_dtype=dtype,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe = pipe.to(device)
|
|
|
def colorize(
    image: Image.Image,
    *,
    prompt: str = "A realistic colorized version of this image.",
    size: tuple = (512, 512),
    strength: float = 1.0,
    guidance_scale: float = 9.0,
    num_inference_steps: int = 30,
) -> Image.Image:
    """Colorize *image* using the module-level ControlNet img2img pipeline ``pipe``.

    The input is reduced to grayscale (then expanded back to 3 channels) so
    that colour uploads still give the ControlNet a grayscale condition. It is
    resized to *size* for the diffusion pass. The generated image is then
    resized back to the input's original dimensions, so the caller gets an
    image with the same shape and aspect ratio it passed in.

    Args:
        image: Source picture; any PIL mode is accepted.
        prompt: Text prompt guiding the colorization.
        size: ``(width, height)`` used for the diffusion pass.
        strength: img2img strength forwarded to the pipeline (1.0 fully
            re-noises the init image; structure then comes from the control image).
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        The colorized image, at the same size as *image*.
    """
    original_size = image.size

    # Collapse to luminance first so colour inputs still yield a grayscale
    # conditioning image, then expand to the 3 channels the pipeline expects.
    gray = image.convert("L").convert("RGB").resize(size)

    result = pipe(
        prompt=prompt,
        image=gray,
        control_image=gray,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    colorized = result.images[0]

    # Undo the fixed-size resize so the output is not stretched/squashed.
    if colorized.size != original_size:
        colorized = colorized.resize(original_size)
    return colorized
|
|
|
def handle_colorize(img):
    """Gradio callback for the Colorize button.

    If the button is clicked with no image loaded, this raises ``gr.Error``
    (shown to the user) instead of letting an ``AttributeError`` escape.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    return colorize(img)


def download_image(img):
    """Save the colorized output to a temporary PNG and return its path.

    ``gr.File`` expects a file path rather than a PIL image, so the image is
    written to disk first. ``delete=False`` keeps the file after the ``with``
    block closes it, so Gradio can still serve it.
    """
    if img is None:
        raise gr.Error("Nothing to download yet - colorize an image first.")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    return tmp.name


with gr.Blocks() as demo:
    gr.Markdown("## 🎨 ColorizeNet - Grayscale to Color Image")
    gr.Markdown("Upload a grayscale image. The model will generate a realistic colorized version.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Grayscale Input", type="pil")
            submit_btn = gr.Button("Colorize")

        with gr.Column():
            output_img = gr.Image(label="Colorized Output", type="pil")
            download_btn = gr.Button("Download")
            # Output target for the Download button, placed explicitly in the
            # layout instead of being created inline inside .click().
            download_file = gr.File(label="Colorized PNG")

    submit_btn.click(fn=handle_colorize, inputs=input_img, outputs=output_img)
    download_btn.click(fn=download_image, inputs=output_img, outputs=download_file)

# Start the server only when run as a script, so importing this module
# does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|