File size: 4,891 Bytes
c42a584
 
 
 
 
 
 
 
 
 
 
 
6c89940
c42a584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42cd559
 
 
c42a584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
972667b
 
 
c42a584
a3ce2c7
 
 
 
 
 
 
5dfcfce
2ab4370
 
 
5dfcfce
a3ce2c7
c42a584
7153b76
c42a584
 
 
8a47d13
e05d8b5
c42a584
 
 
e05d8b5
c42a584
972667b
2ab4370
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import cv2
import numpy as np
import os
from tqdm import tqdm
import torch
from basicsr.archs.ddcolor_arch import DDColor
import torch.nn.functional as F
import gradio as gr
from gradio_imageslider import ImageSlider
import uuid

# Inference configuration consumed by the pipeline instance below:
#   model_path -> checkpoint file passed to torch.load
#   input_size -> side length of the square image the network sees
#   model_size -> 'tiny' picks the ConvNeXt-T encoder, anything else ConvNeXt-L
model_path, input_size, model_size = "pytorch_model.pt", 512, "large"


# Image colorization pipeline built around the DDColor network.
class ImageColorizationPipeline(object):
    """Load a DDColor checkpoint and colorize images with it.

    The network predicts the a/b (chroma) channels of CIE Lab from a
    lightness-only rendering of the resized input; the full-resolution
    lightness of the original image is then recombined with the upsampled
    chroma prediction.
    """

    def __init__(self, model_path, input_size=256, model_size='large'):
        """Build the DDColor model and load its weights.

        Args:
            model_path: checkpoint path for ``torch.load``; the loaded object
                must contain a ``'params'`` entry holding the state dict.
            input_size: side length of the square image fed to the network.
            model_size: ``'tiny'`` selects the ConvNeXt-T encoder; any other
                value selects ConvNeXt-L.
        """
        self.input_size = input_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.encoder_name = 'convnext-t' if model_size == 'tiny' else 'convnext-l'

        # decoder_type is fixed here, so the SingleColorDecoder branch below is
        # currently unreachable.
        self.decoder_type = "MultiScaleColorDecoder"

        if self.decoder_type == 'MultiScaleColorDecoder':
            self.model = DDColor(
                encoder_name=self.encoder_name,
                decoder_name='MultiScaleColorDecoder',
                input_size=[self.input_size, self.input_size],
                num_output_channels=2,
                last_norm='Spectral',
                do_normalize=False,
                num_queries=100,
                num_scales=3,
                dec_layers=9,
            ).to(self.device)
        else:
            self.model = DDColor(
                encoder_name=self.encoder_name,
                decoder_name='SingleColorDecoder',
                input_size=[self.input_size, self.input_size],
                num_output_channels=2,
                last_norm='Spectral',
                do_normalize=False,
                num_queries=256,
            ).to(self.device)

        # Load on CPU first so a GPU-saved checkpoint also works on CPU-only hosts.
        # NOTE(review): strict=False silently ignores missing/unexpected keys, which
        # would also hide a checkpoint that does not match this architecture.
        self.model.load_state_dict(
            torch.load(model_path, map_location=torch.device('cpu'))['params'],
            strict=False)
        self.model.eval()

    @torch.no_grad()
    def process(self, img):
        """Colorize a single image.

        Args:
            img: image array with values in the 0-255 range, in OpenCV BGR
                channel order. A 2-D (or single-channel) grayscale array or a
                4-channel BGRA array is also accepted and converted to 3-channel
                BGR before processing.

        Returns:
            np.ndarray: ``uint8`` BGR image with the same height and width as
            ``img``.
        """
        height, width = img.shape[:2]
        # Still recorded on the instance as before, but the computation below uses
        # the locals so overlapping calls on the shared instance cannot clobber
        # each other's output size.
        self.height, self.width = height, width

        img = (img / 255.0).astype(np.float32)
        # The Lab conversions below require exactly three channels.
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # Full-resolution lightness, reattached to the predicted chroma at the end.
        orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # (h, w, 1)

        # Resize, keep only lightness (a = b = 0) and render it back to RGB as the
        # network input.
        img = cv2.resize(img, (self.input_size, self.input_size))
        img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]
        img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
        img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)

        tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
        # Predicted a/b at network resolution: (1, 2, h', w'), h'/w' presumably input_size.
        output_ab = self.model(tensor_gray_rgb).cpu()

        # Upsample chroma to the original size (F.interpolate's default
        # nearest-neighbour mode) and recombine with the original lightness.
        output_ab_resize = F.interpolate(output_ab, size=(height, width))[0].float().numpy().transpose(1, 2, 0)
        output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
        output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)

        # Clamp before the uint8 cast so any out-of-range value saturates instead
        # of wrapping around.
        return (np.clip(output_bgr, 0.0, 1.0) * 255.0).round().astype(np.uint8)

# Single pipeline instance used by colorize() and colorize_image(); constructing
# it builds the network and loads the checkpoint from model_path.
colorizer = ImageColorizationPipeline(model_path, input_size, model_size)

def colorize_image(image):
    """Colorize an image with the DDColor pipeline and return it as a PIL image.

    Args:
        image: PIL image or array-like whose layout matches what
            ``np.asarray(pil_image)`` yields: 2-D grayscale, RGB, or RGBA.

    Returns:
        PIL.Image.Image: the colorized result in RGB channel order.
    """
    # The original body used ``Image`` without importing it anywhere.
    from PIL import Image

    img_array = np.asarray(image)

    # Reduce to one gray plane first; process() only keeps the lightness anyway.
    if img_array.ndim == 3 and img_array.shape[2] == 4:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
    elif img_array.ndim == 3 and img_array.shape[2] == 3:
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = img_array

    # process() expects a 3-channel BGR array, so re-expand the gray plane.
    colorized_bgr = colorizer.process(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))

    # PIL assumes RGB order; process() returns BGR.
    return Image.fromarray(cv2.cvtColor(colorized_bgr, cv2.COLOR_BGR2RGB))

# Inference callback for the Gradio "Convert" button.
def colorize(img):
    """Colorize ``img`` and return a before/after pair for the image slider.

    Args:
        img: numpy image from the ``gr.Image`` input (RGB by default), or
            ``None`` when Convert is clicked before anything is uploaded.

    Returns:
        ``(img, filename)`` where ``filename`` names a freshly written PNG of the
        colorized result, or ``None`` when there is no input image.
    """
    if img is None:
        # Nothing uploaded: leave the slider empty instead of raising.
        return None

    # The pipeline works in OpenCV's BGR order, while gr.Image hands over RGB(A)
    # arrays; convert so the lightness is computed from the right channels.
    if img.ndim == 2:
        img_bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    else:
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    image_out = colorizer.process(img_bgr)

    # Unique name so concurrent requests never overwrite each other's output.
    # NOTE(review): the relative path lands in the working directory and these
    # files are never cleaned up.
    unique_imgfilename = str(uuid.uuid4()) + '.png'
    cv2.imwrite(unique_imgfilename, image_out)  # process() returns BGR, as imwrite expects
    return (img, unique_imgfilename)
    
def clear_images():
    """Return empty values for the input image and the comparison slider."""
    return (None,) * 2

# Default Gradio theme with teal accents, gray/slate neutrals and a black page.
_theme_hues = dict(primary_hue="teal", secondary_hue="gray", neutral_hue="slate")
dark_theme = gr.themes.Default(**_theme_hues).set(body_background_fill="#000000")

# Extra CSS injected into the page: hides the Gradio footer.
custom_css = """
    footer {
        display: none !important;
    }
"""

# Gradio UI: the black-and-white input with a Convert button on the left, and a
# before/after slider of the colorized result with a Clear button on the right.
with gr.Blocks(theme=dark_theme, css=custom_css) as demo:
    with gr.Row():
        with gr.Column():
            bw_image = gr.Image(label='Black and White Input Image')
            btn = gr.Button('Convert')

        with gr.Column():
            col_image_slider = ImageSlider(position=0.5,
                                           label='Colored Image with Slider-view')
            clear = gr.Button("Clear")

    # Convert: colorize the uploaded image and show the (input, output) pair.
    btn.click(colorize, bw_image, col_image_slider)
    # Clear: blank out both the input image and the slider.
    clear.click(clear_images, outputs=[bw_image, col_image_slider])

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse `demo` or the pipeline) does not bind a port.
if __name__ == "__main__":
    demo.launch(show_api=False)