import gradio as gr
from PIL import Image
import os
import numpy as np
import tensorflow as tf
import requests
from skimage.color import lab2rgb
# Model paths and mapping
load_model_paths = [
"ckpts/autoencoder/autoencoder_colorization_model.h5",
"ckpts/unet/unet_colorization_model.keras",
"ckpts/unet/unet_colorization_model.keras"
]
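# NOTE: the "U-Net v1" and "U-Net v2" entries above currently point to the same
# checkpoint file; supply a second .keras path here if a distinct v2 checkpoint exists.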
# Custom object needed by models
from models.auto_encoder_gray2color import SpatialAttention
# Model input size
WIDTH, HEIGHT = 512, 512
# Download models if they don't exist
def download_model(url, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
print(f"Downloading model from {url}...")
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
print("Download complete.")
# Helper to dynamically load a model
def load_model(model_path):
if not os.path.exists(model_path):
if "autoencoder" in model_path:
url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/autoencoder_colorization_model.h5"
elif "unet" in model_path:
url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/unet_colorization_model.keras"
else:
raise ValueError("Unknown model path for downloading.")
download_model(url, model_path)
print(f"Loading model from {model_path}...")
return tf.keras.models.load_model(
model_path,
custom_objects={'SpatialAttention': SpatialAttention}
)
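# All three models are loaded eagerly at startup so the UI dropdown can switch
# between them without a per-request load; this trades startup time and memory
# for responsiveness. Lazy loading on first selection would be the alternative.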
# Dictionary of loaded models
loaded_models = {
"Autoencoder": load_model(load_model_paths[0]),
"U-Net v1": load_model(load_model_paths[1]),
"U-Net v2": load_model(load_model_paths[2])
}
def process_image(input_img, model_type):
model = loaded_models[model_type]
# Store original input dimensions
original_width, original_height = input_img.size
# Convert PIL Image to grayscale and resize to model input size
img = input_img.convert("L") # Grayscale
img = img.resize((WIDTH, HEIGHT)) # Resize to match model input
    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0  # Normalize to [0, 1]
    img_array = img_array[None, ..., 0:1]  # Add batch dim -> (1, H, W, 1)
# Predict a*b* channels
output_array = model.predict(img_array)
print("Model Output Shape:", output_array.shape)
    # Denormalize: L* back to [0, 100], a*b* to approximately [-128, 128]
    L_channel = img_array[0, :, :, 0] * 100.0
    ab_channels = output_array[0] * 128.0
# Combine into Lab image
lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
# Convert to RGB
rgb_array = lab2rgb(lab_image)
rgb_array = np.clip(rgb_array, 0, 1) * 255.0
rgb_image = Image.fromarray(rgb_array.astype(np.uint8), 'RGB')
# Resize back to original resolution
rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
return rgb_image
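# Example (sketch, commented out): colorize a file outside the Gradio UI.
# The input path is illustrative; any image convertible to grayscale works.
# colorized = process_image(Image.open("examples/example_input_1.jpg"), "Autoencoder")
# colorized.save("colorized_output.jpg")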
custom_css = """
body {background: linear-gradient(135deg, #f0f4f8 0%, #d9e2ec 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #007bff !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #333333 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.1);}
.gr-button {background: linear-gradient(90deg, #007bff 0%, #00c4cc 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""
with gr.Blocks(theme="soft", css=custom_css) as demo:
gr.Markdown("<h1 style='text-align:center;'>πŸŒ„ Gray2Color Landscape Autoencoder</h1>")
gr.Markdown(
"<div style='font-size:1.15em;line-height:1.6em;text-align:center;'>"
"Transform grayscale landscape photos into vivid color using AI.<br>"
"Upload a grayscale image and select a model to begin!"
"</div>"
)
with gr.Row():
image_input = gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L")
image_output = gr.Image(type="pil", label="Colorized Output")
model_selector = gr.Dropdown(
choices=["Autoencoder", "U-Net v1", "U-Net v2"],
label="Select Model",
value="Autoencoder"
)
    run_button = gr.Button("🎨 Colorize")
run_button.click(fn=process_image, inputs=[image_input, model_selector], outputs=image_output)
gr.Examples(
examples=[
["examples/example_input_1.jpg"],
["examples/example_input_2.jpg"]
],
inputs=[image_input],
outputs=image_output,
fn=lambda x: process_image(x, "Autoencoder"), # Default example model choice
cache_examples=True
)
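    # cache_examples=True precomputes outputs for the example images at startup,
    # so the files under examples/ must exist when the app boots.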
if __name__ == "__main__":
demo.launch()
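    # Optional: when serving from a container or remote host, pass e.g.
    # demo.launch(server_name="0.0.0.0", server_port=7860) to expose the app.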