Update app.py
app.py CHANGED

@@ -5,14 +5,18 @@ import numpy as np
 import tensorflow as tf
 import requests
 from skimage.color import lab2rgb
-from tensorflow.keras.optimizers import Adam
 from models.autoencoder_gray2color import SpatialAttention
 from models.unet_gray2color import SelfAttentionLayer
 
 # Set float32 policy
 tf.keras.mixed_precision.set_global_policy('float32')
 
-
+# Model-specific input shapes
+MODEL_INPUT_SHAPES = {
+    "autoencoder": (512, 512),
+    "unet": (1024, 1024),
+    "transformer": (1024, 1024)
+}
 
 # Define model paths
 load_model_paths = [
@@ -51,11 +55,10 @@ for path in load_model_paths:
     models[model_name] = tf.keras.models.load_model(
         path,
         custom_objects=custom_objects[model_name],
-        compile=False
+        compile=False
     )
-    # Recompile the model
    models[model_name].compile(
-        optimizer=Adam(learning_rate=7e-5),
+        optimizer=tf.keras.optimizers.Adam(learning_rate=7e-5),
         loss=tf.keras.losses.MeanSquaredError()
     )
     print(f"{model_name} model loaded.")
@@ -65,30 +68,32 @@ print("All models loaded.")
 def process_image(input_img, model_name):
     # Store original input dimensions
     original_width, original_height = input_img.size
+    # Get model-specific input shape
+    width, height = MODEL_INPUT_SHAPES[model_name.lower()]
     # Convert PIL Image to grayscale and resize to model input size
-    img = input_img.convert("L")
-    img = img.resize((
-    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
-    img_array = img_array[None, ..., 0:1] #
+    img = input_img.convert("L")
+    img = img.resize((width, height))
+    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
+    img_array = img_array[None, ..., 0:1] # Shape: (1, height, width, 1)
 
     # Select model
     selected_model = models[model_name.lower()]
     # Run inference
-    output_array = selected_model.predict(img_array) # Shape: (1,
+    output_array = selected_model.predict(img_array) # Shape: (1, height, width, 2)
 
-    # Extract L*
+    # Extract L* and a*b*
     L_channel = img_array[0, :, :, 0] * 100.0 # Denormalize L* to [0, 100]
     ab_channels = output_array[0] * 128.0 # Denormalize a*b* to [-128, 128]
 
-    # Combine L*, a*, b*
-    lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
+    # Combine L*, a*, b*
+    lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
 
-    # Convert
-    rgb_array = lab2rgb(lab_image)
-    rgb_array = np.clip(rgb_array, 0, 1) * 255.0
-    rgb_image = Image.fromarray(rgb_array.astype(np.uint8), mode="RGB")
+    # Convert to RGB
+    rgb_array = lab2rgb(lab_image)
+    rgb_array = np.clip(rgb_array, 0, 1) * 255.0
+    rgb_image = Image.fromarray(rgb_array.astype(np.uint8), mode="RGB")
 
-    # Resize output
+    # Resize output to original resolution
     rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
     return rgb_image
 
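For reference, a minimal sketch of how the updated process_image could be exercised once the models dict above has loaded, assuming the model keys match those in MODEL_INPUT_SHAPES; the file names here are placeholders, not part of the commit:

    from PIL import Image

    input_img = Image.open("portrait.jpg")        # placeholder path; any RGB or grayscale image
    colorized = process_image(input_img, "unet")  # internally resized to the 1024x1024 U-Net input
    colorized.save("portrait_colorized.png")      # output is resized back to the original resolution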
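The new MODEL_INPUT_SHAPES table also assumes its keys and shapes agree with the loaded models themselves. A quick startup check along these lines (a sketch of my own, not part of the commit, assuming the loaded models are functional Keras models that expose input_shape) would surface any mismatch early:

    for name, model in models.items():
        width, height = MODEL_INPUT_SHAPES[name]  # KeyError here already flags a missing entry
        # Keras reports input_shape as (batch, height, width, channels), with batch = None
        print(f"{name}: declared {(width, height)}, model expects {model.input_shape[1:3]}")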