# face_seg/inference.py
import io
import cv2
import numpy as np
import tensorflow as tf
from sklearn.cluster import KMeans
from collections import Counter
from scipy.spatial import KDTree
from webcolors import hex_to_rgb, rgb_to_hex
from PIL import Image
# Load the pre-trained face-parsing segmentation model (Keras HDF5 format)
model = tf.keras.models.load_model("model.h5")
# Segmentation class labels; index order matches the model's output channels
classes = [
    "background", "skin", "left eyebrow", "right eyebrow",
    "left eye", "right eye", "nose", "upper lip", "inner mouth",
    "lower lip", "hair"
]

def face_skin_extract(pred, image_x):
    """Keep only the pixels predicted as skin (class 1); all other pixels are set to black."""
    output = np.zeros_like(image_x, dtype=np.uint8)
    mask = (pred == 1)
    output[mask] = image_x[mask]
    return output

def extract_dom_color_kmeans(img):
    """Cluster the non-black (skin) pixels with k-means and return the centre of the largest cluster."""
    mask = ~np.all(img == [0, 0, 0], axis=-1)
    non_black_pixels = img[mask]
    k_cluster = KMeans(n_clusters=3, n_init="auto")
    k_cluster.fit(non_black_pixels)
    # Share of pixels assigned to each cluster
    n_pixels = len(k_cluster.labels_)
    counter = Counter(k_cluster.labels_)
    perc = {i: np.round(counter[i] / n_pixels, 2) for i in counter}
    # The centre of the cluster covering the most pixels is the dominant skin colour
    dominant_cluster_index = max(perc, key=perc.get)
    return k_cluster.cluster_centers_[dominant_cluster_index]

def closest_tone_match(rgb_tuple):
    """Find the nearest Monk Skin Tone swatch to the given RGB colour via a KD-tree lookup."""
    skin_tones = {
        'Monk 10': '#292420',
        'Monk 9': '#3a312a',
        'Monk 8': '#604134',
        'Monk 7': '#825c43',
        'Monk 6': '#a07e56',
        'Monk 5': '#d7bd96',
        'Monk 4': '#eadaba',
        'Monk 3': '#f7ead0',
        'Monk 2': '#f3e7db',
        'Monk 1': '#f6ede4'
    }
    rgb_values = []
    names = []
    for monk in skin_tones:
        names.append(monk)
        rgb_values.append(hex_to_rgb(skin_tones[monk]))
    # Nearest-neighbour search in RGB space
    kdt_db = KDTree(rgb_values)
    distance, index = kdt_db.query(rgb_tuple)
    monk_hex = skin_tones[names[index]]
    derived_hex = rgb_to_hex((int(rgb_tuple[0]), int(rgb_tuple[1]), int(rgb_tuple[2])))
    return names[index], monk_hex, derived_hex

def inference(inputs: bytes) -> dict:
    """Run the full pipeline: decode image bytes, segment the face, and match the dominant skin tone."""
    # Decode the raw bytes into an RGB image and resize to the model's input size
    nparr = np.frombuffer(inputs, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_x = cv2.resize(image, (512, 512))
    # Normalise to [0, 1] and add a batch dimension
    image_norm = image_x / 255.0
    image_norm = np.expand_dims(image_norm, axis=0).astype(np.float32)
    # Per-pixel class prediction
    pred = model.predict(image_norm)[0]
    pred = np.argmax(pred, axis=-1).astype(np.int32)
    face_skin = face_skin_extract(pred, image_x)
    dominant_color_rgb = extract_dom_color_kmeans(face_skin)  # NumPy array of RGB floats
    monk_tone, monk_hex, derived_hex = closest_tone_match(
        (dominant_color_rgb[0], dominant_color_rgb[1], dominant_color_rgb[2])
    )
    return {
        "derived_hex_code": derived_hex,
        "monk_hex": monk_hex,
        "monk_skin_tone": monk_tone,
        "dominant_rgb": dominant_color_rgb.tolist()
    }
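

if __name__ == "__main__":
    # Minimal usage sketch, assuming "model.h5" sits next to this script and that
    # "sample_face.jpg" is a hypothetical local image; both file names are assumptions,
    # not part of the original script.
    with open("sample_face.jpg", "rb") as f:
        image_bytes = f.read()
    result = inference(image_bytes)
    # Expected shape of the output:
    # {"derived_hex_code": "#...", "monk_hex": "#...", "monk_skin_tone": "Monk N", "dominant_rgb": [r, g, b]}
    print(result)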