File size: 5,028 Bytes
c8f6bca
 
 
 
 
 
 
d688c98
 
c8f6bca
d688c98
 
 
c8f6bca
d688c98
 
 
c8f6bca
 
d688c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8f6bca
d688c98
c8f6bca
d688c98
c8f6bca
 
 
 
d688c98
 
 
 
 
 
 
c8f6bca
d688c98
 
c8f6bca
d688c98
 
 
c8f6bca
d688c98
 
c8f6bca
d688c98
 
 
 
c8f6bca
d688c98
c8f6bca
 
 
 
 
 
 
 
 
 
 
 
 
d688c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8f6bca
d688c98
 
 
 
 
 
 
 
 
c8f6bca
 
d688c98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
from PIL import Image
import os
import numpy as np
import tensorflow as tf
import requests
from skimage.color import lab2rgb
from models.autoencoder_gray2color import SpatialAttention
from models.unet_gray2color import SelfAttentionLayer

# Spatial size (pixels) every input is resized to before inference.
WIDTH = 512
HEIGHT = 512

# On-disk checkpoint locations, one entry per architecture. The name of the
# directory holding each file is what the loader uses as the model's key.
load_model_paths = [
    "./ckpts/autoencoder/autoencoder_colorization_model.h5",
    "./ckpts/unet/unet_colorization_model.keras",
    "./ckpts/transformer/transformer_colorization_model.keras",
]

# Load models at startup.
#
# For every checkpoint in ``load_model_paths`` the model key is the name of the
# directory containing the file. A checkpoint that is missing on disk is
# downloaded first (when a URL is registered for its key), then loaded with
# Keras and stored in ``models`` under that key.

# Download locations for checkpoints that are not present locally.
url_map = {
    "autoencoder": "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/best_model.h5",
    "unet": "https://example.com/unet_colorization_model.keras",  # Replace with actual URL
    "transformer": "https://example.com/transformer_colorization_model.keras",  # Replace with actual URL
}

# Custom layer classes Keras needs in order to deserialize each model.
custom_objects = {
    "autoencoder": {'SpatialAttention': SpatialAttention},
    "unet": {'SelfAttentionLayer': SelfAttentionLayer},
    "transformer": None,
}


def _download_checkpoint(url, destination, timeout=30):
    """Stream ``url`` into ``destination`` without leaving a partial file there.

    The payload is written to ``<destination>.part`` and only moved onto
    ``destination`` once the whole response has been written. An interrupted
    or failed download therefore never produces a truncated checkpoint that a
    later start-up would mistake for a complete one.

    Args:
        url: HTTP(S) location of the checkpoint.
        destination: Final path of the checkpoint file.
        timeout: Value forwarded to ``requests.get(timeout=...)`` so a stalled
            connection raises instead of hanging start-up forever.

    Raises:
        requests.RequestException: On connection problems, timeouts or a
            non-success HTTP status.
        OSError: If the file cannot be written or moved into place.
    """
    partial_path = f"{destination}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail before creating any file if the server reports an error.
            response.raise_for_status()
            with open(partial_path, "wb") as handle:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        handle.write(chunk)
        os.replace(partial_path, destination)
    finally:
        # After a successful move the partial file no longer exists; on any
        # failure this removes the incomplete leftover.
        if os.path.exists(partial_path):
            os.remove(partial_path)


models = {}
print("Loading models...")
for path in load_model_paths:
    model_name = os.path.basename(os.path.dirname(path))
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        if model_name in url_map:
            print(f"Downloading {model_name} model from {url_map[model_name]}...")
            _download_checkpoint(url_map[model_name], path)
            print(f"Download complete for {model_name}.")

    print(f"Loading {model_name} model from {path}...")
    models[model_name] = tf.keras.models.load_model(
        path,
        custom_objects=custom_objects[model_name],
    )
print("All models loaded.")

def process_image(input_img, model_name):
    """Colorize a grayscale image with one of the preloaded models.

    The image is converted to a single channel, resized to ``WIDTH`` x
    ``HEIGHT``, scaled to [0, 1] and fed to the selected model. The scaled
    input is used as the L* channel and the model output as the a*/b*
    channels of a L*a*b* image, which is converted to RGB and resized back to
    the input's original resolution.

    Args:
        input_img: PIL image to colorize, or ``None`` when nothing was
            provided.
        model_name: Name of the model to use; lower-cased and looked up in
            the module-level ``models`` dict.

    Returns:
        An RGB PIL image with the same width and height as ``input_img``, or
        ``None`` when ``input_img`` is ``None``.

    Raises:
        KeyError: If ``model_name.lower()`` is not a key of ``models``.
    """
    # Nothing to colorize (e.g. submit pressed with an empty image input):
    # return an empty result instead of failing on ``None.size``.
    if input_img is None:
        return None

    # Store original input dimensions
    original_width, original_height = input_img.size

    # Convert PIL Image to grayscale and resize to model input size
    img = input_img.convert("L")  # Convert to grayscale (single channel)
    img = img.resize((WIDTH, HEIGHT))
    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0  # Normalize to [0, 1]
    img_array = img_array[None, ..., 0:1]  # Add batch dimension, keep one channel

    # Select model
    selected_model = models[model_name.lower()]

    # Run inference; the output is expected to hold two channels (a*, b*)
    # per pixel, scaled to [-1, 1] -- assumption about the checkpoints.
    output_array = selected_model.predict(img_array)

    # Extract L* (grayscale input) and a*b* (model output)
    L_channel = img_array[0, :, :, 0] * 100.0  # Denormalize L* to [0, 100]
    ab_channels = output_array[0] * 128.0  # Denormalize a*b* to [-128, 128]

    # Combine L*, a*, b* into a 3-channel L*a*b* image
    lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)

    # Convert L*a*b* to RGB
    rgb_array = lab2rgb(lab_image)
    rgb_array = np.clip(rgb_array, 0, 1) * 255.0  # Scale to [0, 255]
    # A uint8 array of shape (H, W, 3) is read as an RGB image without an
    # explicit ``mode`` argument, which newer Pillow releases warn about.
    rgb_image = Image.fromarray(rgb_array.astype(np.uint8))

    # Resize output image to match input image resolution
    rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)

    return rgb_image

# Stylesheet handed to the interface through its ``css`` argument.
custom_css = """
body {background: linear-gradient(135deg, #f0f4f8 0%, #d9e2ec 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #007bff !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #333333 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.1);}
.gr-button {background: linear-gradient(90deg, #007bff 0%, #00c4cc 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""

# Input widgets, in the order ``process_image`` receives their values:
# the uploaded picture first, then the chosen model name.
_image_input = gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L")
_model_selector = gr.Dropdown(
    choices=["Autoencoder", "Unet", "Transformer"],
    label="Select Model",
    value="Autoencoder",
)

# Widget that displays the image returned by ``process_image``.
_image_output = gr.Image(type="pil", label="Colorized Output")

# HTML snippet passed as the interface description.
_description_html = (
    "<div style='font-size:1.15em;line-height:1.6em;'>"
    "Transform your <b>grayscale landscape</b> photos into vivid color using advanced deep learning models.<br>"
    "Upload a grayscale image, select a model (Autoencoder, U-Net, or Transformer), and see the results!"
    "</div>"
)

# Example rows; each one supplies an image path and a model name.
_example_rows = [
    ["examples/example_input_1.jpg", "Autoencoder"],
    ["examples/example_input_2.jpg", "Unet"],
]

demo = gr.Interface(
    fn=process_image,
    inputs=[_image_input, _model_selector],
    outputs=_image_output,
    title="🌄 Gray2Color Landscape Colorization",
    description=_description_html,
    theme="soft",
    css=custom_css,
    allow_flagging="never",
    examples=_example_rows,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()