|
import gradio as gr |
|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
from model import Generator |
|
|
|
|
|
# Run on the GPU when CUDA reports itself available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the generator on the chosen device, then restore its weights from
# the checkpoint file, remapping stored tensors onto that same device.
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# if the checkpoint's origin is not fully trusted -- confirm the installed
# torch version supports it.
model = Generator().to(device)
model.load_state_dict(
    torch.load('generator.pth', map_location=device),
)

# Switch the module to evaluation mode for inference.
model.eval()
|
|
|
|
|
# Input pipeline: resize to 256x256, collapse to one grayscale channel,
# then convert the image to a tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)

# Output pipeline: turn a tensor back into a PIL image.
postprocess = transforms.ToPILImage()
|
|
|
def colorize_image(input_image):
    """Run the generator on an uploaded image and return the result.

    Args:
        input_image: Image supplied by the UI, or ``None`` when nothing has
            been uploaded yet. It is fed through the module-level
            ``preprocess`` transform before reaching the model.

    Returns:
        The model output converted by the module-level ``postprocess``
        transform, or ``None`` when ``input_image`` is ``None``.
    """
    # The button can be clicked before anything is uploaded; return an
    # empty result instead of failing inside the transform pipeline.
    if input_image is None:
        return None

    # Add a leading batch dimension and move the tensor to the model's device.
    input_tensor = preprocess(input_image).unsqueeze(0).to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Drop the batch dimension, move to the CPU and clamp values to [0, 1]
    # before converting back to an image.
    return postprocess(output_tensor.squeeze(0).cpu().clamp(0, 1))
|
|
|
def reset():
    """Return empty values for the two image components.

    Returns:
        A ``(None, None)`` tuple: one entry for each output component the
        reset button is wired to.
    """
    cleared_input = None
    cleared_output = None
    return cleared_input, cleared_output
|
|
|
# UI definition: layout, event wiring and the download helper.
# NOTE(review): the nesting of rows/columns below was reconstructed from the
# statement order; confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Image Colorization App")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload your grayscale image", type="pil")
            clear_button = gr.Button("🔄 Reset / Clear")
            download_button = gr.File(label="Download Colorized Image")
        with gr.Column():
            # type="pil" so that handlers reading this component (see
            # prepare_download) receive a PIL image; without it Gradio
            # passes a NumPy array, which has no .save() and cannot be
            # used in a boolean test.
            output_image = gr.Image(label="Colorized Image", type="pil")

    colorize_btn = gr.Button("✨ Colorize Image")

    # Colorize: uploaded image in, generated image out.
    colorize_btn.click(
        colorize_image,
        inputs=input_image,
        outputs=output_image,
    )

    # Reset: clear both image components.
    clear_button.click(
        reset,
        inputs=[],
        outputs=[input_image, output_image],
    )

    def prepare_download(image):
        """Save the colorized image to disk and return its path.

        Args:
            image: Current value of the output image component (a PIL
                image), or ``None`` when the component is empty.

        Returns:
            The path of the written PNG file, or ``None`` when there is
            no image to save.
        """
        # Explicit None check: truthiness is not a reliable "has image" test.
        if image is None:
            return None

        path = "colorized_output.png"
        image.save(path)
        return path

    # Refresh the downloadable file whenever the output image changes
    # (including when it is cleared).
    output_image.change(
        prepare_download,
        inputs=output_image,
        outputs=download_button,
    )

demo.launch()
|
|