danhtran2mind committed on
Commit
e89a371
·
verified ·
1 Parent(s): 437b632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -5,14 +5,18 @@ import numpy as np
5
  import tensorflow as tf
6
  import requests
7
  from skimage.color import lab2rgb
8
- from tensorflow.keras.optimizers import Adam
9
  from models.autoencoder_gray2color import SpatialAttention
10
  from models.unet_gray2color import SelfAttentionLayer
11
 
12
  # Set float32 policy
13
  tf.keras.mixed_precision.set_global_policy('float32')
14
 
15
- WIDTH, HEIGHT = 512, 512
 
 
 
 
 
16
 
17
  # Define model paths
18
  load_model_paths = [
@@ -51,11 +55,10 @@ for path in load_model_paths:
51
  models[model_name] = tf.keras.models.load_model(
52
  path,
53
  custom_objects=custom_objects[model_name],
54
- compile=False # Skip optimizer state
55
  )
56
- # Recompile the model
57
  models[model_name].compile(
58
- optimizer=Adam(learning_rate=7e-5),
59
  loss=tf.keras.losses.MeanSquaredError()
60
  )
61
  print(f"{model_name} model loaded.")
@@ -65,30 +68,32 @@ print("All models loaded.")
65
  def process_image(input_img, model_name):
66
  # Store original input dimensions
67
  original_width, original_height = input_img.size
 
 
68
  # Convert PIL Image to grayscale and resize to model input size
69
- img = input_img.convert("L") # Convert to grayscale
70
- img = img.resize((WIDTH, HEIGHT)) # Resize to 512x512
71
- img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0 # Normalize to [0, 1]
72
- img_array = img_array[None, ..., 0:1] # Add batch dimension, shape: (1, 512, 512, 1)
73
 
74
  # Select model
75
  selected_model = models[model_name.lower()]
76
  # Run inference
77
- output_array = selected_model.predict(img_array) # Shape: (1, 512, 512, 2) for a*b*
78
 
79
- # Extract L* (grayscale input) and a*b* (model output)
80
  L_channel = img_array[0, :, :, 0] * 100.0 # Denormalize L* to [0, 100]
81
  ab_channels = output_array[0] * 128.0 # Denormalize a*b* to [-128, 128]
82
 
83
- # Combine L*, a*, b* into a 3-channel L*a*b* image
84
- lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1) # Shape: (512, 512, 3)
85
 
86
- # Convert L*a*b* to RGB
87
- rgb_array = lab2rgb(lab_image) # Convert to RGB, output in [0, 1]
88
- rgb_array = np.clip(rgb_array, 0, 1) * 255.0 # Scale to [0, 255]
89
- rgb_image = Image.fromarray(rgb_array.astype(np.uint8), mode="RGB") # Create RGB PIL image
90
 
91
- # Resize output image to match input resolution
92
  rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
93
  return rgb_image
94
 
 
5
  import tensorflow as tf
6
  import requests
7
  from skimage.color import lab2rgb
 
8
  from models.autoencoder_gray2color import SpatialAttention
9
  from models.unet_gray2color import SelfAttentionLayer
10
 
11
  # Set float32 policy
12
  tf.keras.mixed_precision.set_global_policy('float32')
13
 
14
+ # Model-specific input sizes as (width, height), keyed by lowercase model name
15
+ MODEL_INPUT_SHAPES = {
16
+ "autoencoder": (512, 512),
17
+ "unet": (1024, 1024),
18
+ "transformer": (1024, 1024)
19
+ }
20
 
21
  # Define model paths
22
  load_model_paths = [
 
55
  models[model_name] = tf.keras.models.load_model(
56
  path,
57
  custom_objects=custom_objects[model_name],
58
+ compile=False
59
  )
 
60
  models[model_name].compile(
61
+ optimizer=tf.keras.optimizers.Adam(learning_rate=7e-5),
62
  loss=tf.keras.losses.MeanSquaredError()
63
  )
64
  print(f"{model_name} model loaded.")
 
68
  def process_image(input_img, model_name):
69
  # Store original input dimensions
70
  original_width, original_height = input_img.size
71
+ # Get model-specific input shape
72
+ width, height = MODEL_INPUT_SHAPES[model_name.lower()]
73
  # Convert PIL Image to grayscale and resize to model input size
74
+ img = input_img.convert("L")
75
+ img = img.resize((width, height))
76
+ img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
77
+ img_array = img_array[None, ..., 0:1] # Shape: (1, height, width, 1)
78
 
79
  # Select model
80
  selected_model = models[model_name.lower()]
81
  # Run inference
82
+ output_array = selected_model.predict(img_array) # Shape: (1, height, width, 2)
83
 
84
+ # Extract L* and a*b*
85
  L_channel = img_array[0, :, :, 0] * 100.0 # Denormalize L* to [0, 100]
86
  ab_channels = output_array[0] * 128.0 # Denormalize a*b* to [-128, 128]
87
 
88
+ # Combine L*, a*, b*
89
+ lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
90
 
91
+ # Convert to RGB
92
+ rgb_array = lab2rgb(lab_image)
93
+ rgb_array = np.clip(rgb_array, 0, 1) * 255.0
94
+ rgb_image = Image.fromarray(rgb_array.astype(np.uint8), mode="RGB")
95
 
96
+ # Resize output to original resolution
97
  rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
98
  return rgb_image
99