Spaces:
Runtime error
Runtime error
File size: 4,891 Bytes
c42a584 6c89940 c42a584 42cd559 c42a584 972667b c42a584 a3ce2c7 5dfcfce 2ab4370 5dfcfce a3ce2c7 c42a584 7153b76 c42a584 8a47d13 e05d8b5 c42a584 e05d8b5 c42a584 972667b 2ab4370 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import argparse
import cv2
import numpy as np
import os
from tqdm import tqdm
import torch
from basicsr.archs.ddcolor_arch import DDColor
import torch.nn.functional as F
import gradio as gr
from gradio_imageslider import ImageSlider
import uuid
# Checkpoint file handed to ImageColorizationPipeline as ``model_path``.
model_path = "pytorch_model.pt"
# Passed to the pipeline as ``input_size`` (the side length images are
# resized to before they are fed to the model).
input_size = 512
# 'tiny' selects the 'convnext-t' encoder; any other value selects 'convnext-l'.
model_size = 'large'
# Create Image Colorization Pipeline
class ImageColorizationPipeline(object):
    """Wrap a DDColor model and colorize single images with it.

    The constructor builds the network, loads the checkpoint found at
    ``model_path`` and puts the model in eval mode; :meth:`process` runs one
    image through it.
    """

    def __init__(self, model_path, input_size=256, model_size='large'):
        """Build the DDColor network and load its weights.

        Args:
            model_path: Checkpoint path given to ``torch.load``; the state
                dict is read from its ``'params'`` entry.
            input_size: Side length of the square image fed to the model.
            model_size: ``'tiny'`` selects the ``'convnext-t'`` encoder; any
                other value selects ``'convnext-l'``.
        """
        self.input_size = input_size
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder_name = (
            'convnext-t' if model_size == 'tiny' else 'convnext-l')

        # The decoder type is fixed here, so only the first branch below is
        # taken; the single-decoder configuration is kept for reference.
        self.decoder_type = "MultiScaleColorDecoder"
        if self.decoder_type == 'MultiScaleColorDecoder':
            self.model = DDColor(
                encoder_name=self.encoder_name,
                decoder_name='MultiScaleColorDecoder',
                input_size=[self.input_size, self.input_size],
                num_output_channels=2,
                last_norm='Spectral',
                do_normalize=False,
                num_queries=100,
                num_scales=3,
                dec_layers=9,
            ).to(self.device)
        else:
            self.model = DDColor(
                encoder_name=self.encoder_name,
                decoder_name='SingleColorDecoder',
                input_size=[self.input_size, self.input_size],
                num_output_channels=2,
                last_norm='Spectral',
                do_normalize=False,
                num_queries=256,
            ).to(self.device)

        # Weights are first mapped to CPU; load_state_dict copies them into
        # the model that already lives on self.device.  strict=False lets
        # keys that do not match the architecture be skipped.
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the torch version -- only load trusted checkpoint files.
        self.model.load_state_dict(
            torch.load(model_path, map_location=torch.device('cpu'))['params'],
            strict=False)
        self.model.eval()

    @torch.no_grad()
    def process(self, img):
        """Colorize one image and return the result as a BGR ``uint8`` array.

        Args:
            img: Image array with values on a 0..255 scale.  Three-channel
                input is interpreted in BGR order (it is converted with
                ``cv2.COLOR_BGR2Lab``).  2-D or single-channel input is
                replicated to three channels, and the fourth channel of a
                4-channel input is dropped.

        Returns:
            ``np.uint8`` array of shape ``(height, width, 3)`` in BGR order
            with the same height and width as ``img``.

        Raises:
            ValueError: If ``img`` is not a 2-D array or a 3-D array with
                1, 3 or 4 channels.
        """
        img = np.asarray(img)
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        if img.ndim != 3 or img.shape[2] not in (1, 3, 4):
            raise ValueError(
                'expected an image of shape (H, W), (H, W, 1), (H, W, 3) '
                'or (H, W, 4), got shape %s' % (img.shape,))
        if img.shape[2] == 1:
            # Grey input: the Lab conversion below needs three channels.
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:
            # Drop the fourth (alpha) channel.
            img = img[:, :, :3]

        # Work with locals so overlapping calls on the same pipeline cannot
        # overwrite each other's target size; the attributes are still set
        # for code that reads them after a call.
        height, width = img.shape[:2]
        self.height, self.width = height, width

        img = (img / 255.0).astype(np.float32)
        # Lightness of the full-resolution input, kept for the final image.
        orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # (h, w, 1)

        # resize -> Lab -> keep L only (a = b = 0) -> back to RGB, i.e. a
        # grey RGB image at the model's input resolution.
        img = cv2.resize(img, (self.input_size, self.input_size))
        img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]
        img_gray_lab = np.concatenate(
            (img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
        img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)

        # HWC -> NCHW batch of one on the model's device.
        tensor_gray_rgb = torch.from_numpy(
            img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
        # Predicted ab channels; presumably (1, 2, H', W') at the model's
        # own output resolution -- they are resized to the input size below.
        output_ab = self.model(tensor_gray_rgb).cpu()

        # resize ab -> concat original L -> BGR
        output_ab_resize = F.interpolate(
            output_ab, size=(height, width))[0].float().numpy().transpose(1, 2, 0)
        output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
        output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)

        output_img = (output_bgr * 255.0).round().astype(np.uint8)
        return output_img
# Build the pipeline once at import time from the module-level settings; the
# functions below all reuse this single instance.
colorizer = ImageColorizationPipeline(model_path=model_path,
                                      input_size=input_size,
                                      model_size=model_size)
def colorize_image(image):
    """Colorize an image with the DDColor model and return a PIL image.

    Args:
        image: A PIL image or array-like with ``uint8`` pixel values.
            Three- and four-channel input is assumed to be in RGB(A) order
            (the PIL convention); 2-D input is treated as greyscale.

    Returns:
        ``PIL.Image.Image`` in RGB order holding the colorized picture.

    Raises:
        ValueError: If ``image`` does not have 1, 3 or 4 channels.
    """
    # Imported here because the file has no top-level Pillow import.
    # NOTE(review): Pillow is expected to be installed alongside gradio --
    # confirm for the deployment environment.
    from PIL import Image

    img_array = np.array(image)

    # The pipeline expects a 3-channel BGR array, so bring every accepted
    # layout into that form (it discards the colour information itself).
    if img_array.ndim == 2:
        bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
    elif img_array.ndim == 3 and img_array.shape[2] == 1:
        bgr = cv2.cvtColor(img_array[:, :, 0], cv2.COLOR_GRAY2BGR)
    elif img_array.ndim == 3 and img_array.shape[2] == 3:
        bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    elif img_array.ndim == 3 and img_array.shape[2] == 4:
        bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
    else:
        raise ValueError(
            'unsupported image shape %s' % (img_array.shape,))

    # Colorize the image (result is BGR uint8).
    colorized_bgr = colorizer.process(bgr)

    # PIL expects RGB channel order, so swap before wrapping.
    colorized_rgb = cv2.cvtColor(colorized_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(colorized_rgb)
# Create inference function for gradio app
def colorize(img):
    """Gradio handler: colorize ``img`` and return the before/after pair.

    Args:
        img: The value of the input ``gr.Image`` component.  With the
            component's default settings this is an RGB ``uint8`` numpy
            array, or ``None`` when no image has been provided.

    Returns:
        ``(img, path)`` where ``img`` is the untouched input and ``path`` is
        the PNG file the colorized image was written to.

    Raises:
        gr.Error: If no image was provided or the result cannot be saved.
    """
    if img is None:
        # Convert was pressed without an image; show a message in the UI
        # instead of failing inside the pipeline.
        raise gr.Error("Please upload an image before pressing Convert.")

    # The pipeline works in OpenCV's BGR order while Gradio delivers RGB.
    pixels = np.asarray(img)
    if pixels.ndim == 2:
        bgr = cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
    elif pixels.ndim == 3 and pixels.shape[2] == 4:
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGR)
    else:
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    image_out = colorizer.process(bgr)

    # Generate a unique filename using UUID so requests never overwrite each
    # other's output.
    # NOTE(review): files are written to the working directory and are never
    # removed here.
    unique_imgfilename = str(uuid.uuid4()) + '.png'
    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(unique_imgfilename, image_out):
        raise gr.Error("Could not save the colorized image.")
    return (img, unique_imgfilename)
def clear_images():
    """Return a ``(None, None)`` pair, one value per component to reset."""
    cleared_input = None
    cleared_output = None
    return cleared_input, cleared_output
# Theme handed to gr.Blocks below: Gradio's Default theme with custom hues and
# the page background overridden.
dark_theme = gr.themes.Default(
    primary_hue="teal",
    secondary_hue="gray",
    neutral_hue="slate",
).set(
    body_background_fill="#000000" # Set background to black
)
# Extra CSS handed to gr.Blocks below; it hides the page's <footer> element.
custom_css = """
footer {
display: none !important;
}
"""
# Gradio demo using the Image-Slider custom component
with gr.Blocks(css=custom_css, theme=dark_theme) as demo:
    # NOTE(review): the nesting below is reconstructed -- the source had lost
    # its indentation.  The Clear button is assumed to sit under the output
    # slider; confirm against the intended layout.
    with gr.Row():
        # Left column: the input image and the button that starts conversion.
        with gr.Column():
            bw_image = gr.Image(label='Black and White Input Image')
            btn = gr.Button('Convert')
        # Right column: before/after slider plus the reset button.
        with gr.Column():
            col_image_slider = ImageSlider(
                label='Colored Image with Slider-view',
                position=0.5,
            )
            clear = gr.Button("Clear")

    # Event wiring has to happen inside the Blocks context.
    btn.click(fn=colorize, inputs=bw_image, outputs=col_image_slider)
    clear.click(fn=clear_images, outputs=[bw_image, col_image_slider])

demo.launch(show_api=False)
|