File size: 4,893 Bytes
c8f6bca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from PIL import Image
import os
import numpy as np
import tensorflow as tf
import requests
from skimage.color import lab2rgb

# Checkpoint paths for the selectable models, consumed by index when
# `loaded_models` is built: [0] Autoencoder, [1] U-Net v1, [2] U-Net v2.
# A path that does not exist locally is downloaded on first use by `load_model`.
# NOTE(review): entries [1] and [2] are identical, so "U-Net v1" and
# "U-Net v2" currently load the same weights — confirm whether a separate
# v2 checkpoint was intended.
load_model_paths = [
    "ckpts/autoencoder/autoencoder_colorization_model.h5",
    "ckpts/unet/unet_colorization_model.keras",
    "ckpts/unet/unet_colorization_model.keras"
]

# Custom object needed by models
from models.auto_encoder_gray2color import SpatialAttention

# Fixed spatial size (in pixels) that every image is resized to before it is
# fed to a model; square, so width and height share one value.
WIDTH = HEIGHT = 512

# Download models if they don't exist
def download_model(url, path, timeout=60):
    """Download the file at ``url`` to ``path``, streaming it in chunks.

    The payload is written to a temporary ``<path>.part`` file and moved onto
    ``path`` only after the whole response has been received. An interrupted
    download therefore never leaves a truncated file at ``path``, which
    ``load_model`` would otherwise mistake for an existing checkpoint.

    Args:
        url: HTTP(S) URL to fetch.
        path: Destination file path; missing parent directories are created.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot hang startup indefinitely.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    directory = os.path.dirname(path)
    if directory:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(directory, exist_ok=True)
    print(f"Downloading model from {url}...")
    tmp_path = path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: `path` is either absent or complete.
        os.replace(tmp_path, path)
    except BaseException:
        # Drop the partial file (also on Ctrl-C) so a retry starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download complete.")

# Helper to dynamically load a model
def load_model(model_path):
    """Load a Keras checkpoint, fetching it from the Hugging Face Hub if it is missing.

    If ``model_path`` does not exist locally, the remote file is chosen by a
    substring match on the path. A path containing "autoencoder" maps to the
    autoencoder ``.h5`` checkpoint; otherwise a path containing "unet" maps to
    the U-Net ``.keras`` checkpoint. The chosen file is then saved to
    ``model_path`` via ``download_model``.

    Args:
        model_path: Local checkpoint path.

    Returns:
        The model returned by ``tf.keras.models.load_model``, with the custom
        ``SpatialAttention`` layer registered for deserialization.

    Raises:
        ValueError: If the file is missing and the path matches neither family.
    """
    if not os.path.exists(model_path):
        # Checked in order, so "autoencoder" wins if a path mentions both.
        remote_checkpoints = (
            ("autoencoder", "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/autoencoder_colorization_model.h5"),
            ("unet", "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/unet_colorization_model.keras"),
        )
        for marker, remote_url in remote_checkpoints:
            if marker in model_path:
                break
        else:
            raise ValueError("Unknown model path for downloading.")
        download_model(remote_url, model_path)

    print(f"Loading model from {model_path}...")
    custom_layers = {'SpatialAttention': SpatialAttention}
    return tf.keras.models.load_model(model_path, custom_objects=custom_layers)

# Display name -> model instance, loaded eagerly at import time. Names are
# paired positionally with `load_model_paths`, and the models are loaded in
# the order listed here.
# NOTE(review): "U-Net v1" and "U-Net v2" are built from the same checkpoint
# path, so the same weights are loaded twice.
loaded_models = {
    display_name: load_model(checkpoint)
    for display_name, checkpoint in zip(
        ("Autoencoder", "U-Net v1", "U-Net v2"), load_model_paths
    )
}

def process_image(input_img, model_type):
    """Colorize a grayscale image with the selected model.

    The image is converted to grayscale, resized to ``WIDTH`` x ``HEIGHT`` and
    scaled to [0, 1] before being passed to the model. The model output is
    treated as the a*/b* chroma channels of a CIE Lab image, scaled by 128.
    The input intensity, scaled to [0, 100], is reused as the L* channel.
    The assembled Lab image is converted to RGB and resized back to the
    input's original resolution.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.
        model_type: Key into ``loaded_models`` (the dropdown value).

    Returns:
        The colorized RGB ``PIL.Image.Image`` at the input's original size.

    Raises:
        gr.Error: If no image was provided (shown as a message in the UI).
        KeyError: If ``model_type`` is not a loaded model name.
    """
    if input_img is None:
        # Clicking "Colorize" with an empty input passes None; report it in the
        # UI instead of crashing with an AttributeError on `.size`.
        raise gr.Error("Please upload an image first.")

    model = loaded_models[model_type]

    # Remember the upload's size so the result is returned at that size.
    original_width, original_height = input_img.size

    # Grayscale -> model input size -> [0, 1] floats.
    img = input_img.convert("L")
    img = img.resize((WIDTH, HEIGHT))
    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
    img_array = img_array[None, ..., 0:1]  # add batch axis, keep one channel

    # Predict the a*/b* channels.
    output_array = model.predict(img_array)
    print("Model Output Shape:", output_array.shape)

    # L* spans [0, 100]; the input intensity is reused as lightness.
    L_channel = img_array[0, :, :, 0] * 100.0
    # Assumes the model emits a*/b* normalized to roughly [-1, 1] — confirm
    # against the training code.
    ab_channels = output_array[0] * 128.0

    lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)

    rgb_array = lab2rgb(lab_image)
    rgb_array = np.clip(rgb_array, 0, 1) * 255.0
    # Round rather than truncate: a plain astype() floors, so e.g. 254.9 would
    # become 254 and every channel would be biased slightly dark.
    rgb_image = Image.fromarray(np.round(rgb_array).astype(np.uint8), 'RGB')

    # Back to the caller's original resolution.
    return rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)

# Extra stylesheet passed to gr.Blocks(css=...): gradient page background,
# transparent app container, blue title font, rounded/shadowed input and
# output panels, and gradient buttons.
# NOTE(review): selectors such as .gr-input / .gr-output / .gr-button /
# .gr-title / .gr-description may not match the class names emitted by the
# installed Gradio version — confirm these rules actually take effect.
custom_css = """
body {background: linear-gradient(135deg, #f0f4f8 0%, #d9e2ec 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #007bff !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #333333 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.1);}
.gr-button {background: linear-gradient(90deg, #007bff 0%, #00c4cc 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""

# Gradio UI: title/description, input and output images side by side, a model
# picker, a Colorize button, and cached example images.
with gr.Blocks(theme="soft", css=custom_css) as demo:
    gr.Markdown("<h1 style='text-align:center;'>🌄 Gray2Color Landscape Autoencoder</h1>")
    gr.Markdown(
        "<div style='font-size:1.15em;line-height:1.6em;text-align:center;'>"
        "Transform grayscale landscape photos into vivid color using AI.<br>"
        "Upload a grayscale image and select a model to begin!"
        "</div>"
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L")
        image_output = gr.Image(type="pil", label="Colorized Output")
    model_selector = gr.Dropdown(
        # Derived from the loaded models so the choices can never drift from
        # the keys process_image looks up.
        choices=list(loaded_models),
        label="Select Model",
        value="Autoencoder"
    )
    run_button = gr.Button("🎨 Colorize")
    run_button.click(fn=process_image, inputs=[image_input, model_selector], outputs=image_output)

    # Each example row also sets the model dropdown. The cached output shown
    # when an example is clicked therefore always matches the model displayed
    # in the dropdown.
    gr.Examples(
        examples=[
            ["examples/example_input_1.jpg", "Autoencoder"],
            ["examples/example_input_2.jpg", "Autoencoder"],
        ],
        inputs=[image_input, model_selector],
        outputs=image_output,
        fn=process_image,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()