Spaces:
Running
Running
import asyncio | |
import base64 | |
import logging | |
import os | |
import time | |
import cv2 | |
import gradio as gr | |
import httpx | |
import numpy as np | |
import requests | |
from gradio.themes.utils import sizes | |
# LOGGING
logger = logging.getLogger("LookSwap")
logger.setLevel(logging.INFO)
# getLogger() returns the same logger object every time this module is executed in one
# interpreter (e.g. when it is reloaded), so only attach the console handler once;
# otherwise every record would be printed once per attached handler.
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
# IMAGE ASSETS
ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
_WATERMARK_PATH = os.path.join(ASSETS_DIR, "watermark.png")
# IMREAD_UNCHANGED keeps the PNG's alpha channel, which add_watermark uses as the blend mask.
WATERMARK = cv2.imread(_WATERMARK_PATH, cv2.IMREAD_UNCHANGED)
if WATERMARK is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of raising;
    # fail here with a clear message rather than inside cv2.resize below.
    raise FileNotFoundError(f"Could not read watermark image: {_WATERMARK_PATH}")
# Shrink the watermark to half size once at startup.
WATERMARK = cv2.resize(WATERMARK, (0, 0), fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)
# Placeholder image returned to the UI when the API rejects an input as NSFW.
NSFW = os.path.join(ASSETS_DIR, "nsfw.webp")
# API CONFIG
# NOTE(review): "ENPOINT" looks like a typo for "ENDPOINT", but renaming the variable would
# break any deployment that already sets it, so the name is kept as-is.
FASHN_API_URL = os.environ.get("FASHN_ENPOINT_URL")
FASHN_API_KEY = os.environ.get("FASHN_API_KEY")
# Explicit checks instead of `assert`: assertions are stripped when Python runs with -O,
# which would let the app start without an endpoint or credentials.
if not FASHN_API_URL:
    raise RuntimeError("Please set the FASHN_ENPOINT_URL environment variable")
if not FASHN_API_KEY:
    raise RuntimeError("Please set the FASHN_API_KEY environment variable")
# ----------------- HELPER FUNCTIONS ----------------- #
def add_watermark(image: np.ndarray, watermark: np.ndarray, offset: int = 5) -> np.ndarray:
    """Alpha-blend ``watermark`` onto the bottom-right corner of ``image``, in place.

    Args:
        image: H x W x C image (C >= 3) to draw on. Only its first three channels are
            written; the array is modified in place and also returned.
        watermark: h x w x 4 overlay whose 4th channel is the alpha mask
            (0 = transparent, 255 = opaque). A 3-channel overlay is treated as opaque.
            Its color channels must use the same channel order as ``image``.
        offset: Gap in pixels between the watermark and the right/bottom image edges.

    Returns:
        The same ``image`` array with the watermark blended in. If the watermark does not
        fit, only the part that lies inside the image is drawn.
    """
    image_height, image_width = image.shape[:2]
    watermark_height, watermark_width = watermark.shape[:2]

    # Top-left corner of the watermark when anchored to the bottom-right corner.
    x_offset = image_width - watermark_width - offset
    y_offset = image_height - watermark_height - offset

    # Crop whatever part of the watermark would fall outside the image, so an oversized
    # watermark (or a small image) no longer fails with a shape-mismatch error.
    crop_left, crop_top = max(0, -x_offset), max(0, -y_offset)
    x_offset, y_offset = max(0, x_offset), max(0, y_offset)
    visible = watermark[crop_top:, crop_left:]
    visible = visible[: max(0, image_height - y_offset), : max(0, image_width - x_offset)]
    height, width = visible.shape[:2]
    if height == 0 or width == 0:
        return image

    overlay_color = visible[:, :, :3]
    if visible.shape[2] >= 4:
        # Keep a trailing axis so the mask broadcasts over the three color channels.
        alpha = visible[:, :, 3:4] / 255.0
    else:
        alpha = np.ones((height, width, 1))  # no alpha channel: fully opaque

    # `region` is a view into `image`, so assigning to it writes the result in place.
    region = image[y_offset : y_offset + height, x_offset : x_offset + width, :3]
    # color * a + background * (1 - a); astype truncates toward zero (no rounding).
    region[...] = (overlay_color * alpha + region * (1.0 - alpha)).astype(image.dtype)
    return image
def opencv_load_image_from_http(url: str, timeout: float = 30.0) -> np.ndarray:
    """Download the image at ``url`` and decode it with OpenCV.

    Args:
        url: HTTP(S) URL of an encoded image.
        timeout: Seconds to wait for the server (connect/read). Without a timeout
            `requests` may block indefinitely on an unresponsive server.

    Returns:
        The decoded 3-channel image in OpenCV's BGR order (``cv2.IMREAD_COLOR``).

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection problems or timeouts.
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        image_data = np.frombuffer(response.content, np.uint8)
    image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals undecodable input by returning None instead of raising.
        raise ValueError(f"Could not decode image downloaded from {url}")
    return image
def resize_image(img: np.ndarray, short_axis_target: int = 512) -> np.ndarray:
    """Scale ``img`` uniformly so that its shorter side becomes ``short_axis_target`` pixels.

    The aspect ratio is preserved. Images whose shorter side is already below the target
    are enlarged, not left unchanged.

    Args:
        img: Image array with rows (height) on axis 0 and columns (width) on axis 1.
        short_axis_target: Desired length in pixels of the shorter side.

    Returns:
        The resized image.

    Raises:
        ValueError: if the image has a zero-length side.
    """
    height, width = img.shape[:2]
    short_side = min(height, width)
    if short_side == 0:
        raise ValueError("Cannot resize an image with a zero-length side")
    scale_factor = short_axis_target / short_side
    # cv2.resize takes (width, height).
    new_size = (max(1, round(width * scale_factor)), max(1, round(height * scale_factor)))
    # INTER_AREA is the recommended filter for shrinking, but when enlarging it behaves like
    # nearest-neighbour; use bicubic interpolation for upscaling instead.
    interpolation = cv2.INTER_AREA if scale_factor <= 1 else cv2.INTER_CUBIC
    return cv2.resize(img, new_size, interpolation=interpolation)
def encode_img_to_base64(img: np.ndarray) -> str:
    """Encode ``img`` as JPEG and return it as a ``data:image/jpeg;base64,...`` URI.

    Channels are interpreted in OpenCV's BGR order, as with any ``cv2.imencode`` call.

    Raises:
        ValueError: if OpenCV reports that the JPEG encoding failed.
    """
    ok, jpeg_buffer = cv2.imencode(".jpg", img)
    if not ok:
        raise ValueError("Failed to JPEG-encode image")
    b64_payload = base64.b64encode(jpeg_buffer.tobytes()).decode("utf-8")
    return f"data:image/jpeg;base64,{b64_payload}"
def parse_checkboxes(checkboxes) -> dict:
    """Turn the selected "Additional Controls" labels into boolean API flags.

    Each selected label is lower-cased and its spaces become underscores, e.g.
    ``["Remove Accessories"]`` -> ``{"remove_accessories": True}``. Unselected options are
    simply absent. ``None`` (no selection reported) is treated like an empty selection.

    Args:
        checkboxes: Iterable of selected label strings, or None.

    Returns:
        Dict mapping each normalized label to ``True``.
    """
    return {label.lower().replace(" ", "_"): True for label in checkboxes or ()}
def verify_aspect_ratio(
    img: np.ndarray, prefix: str = "Model", *, min_ratio: float = 0.5, max_ratio: float = 0.8
):
    """Reject images whose width:height ratio lies outside ``[min_ratio, max_ratio]``.

    The default bounds accept portrait images such as 2:3 (~0.67) and 3:4 (0.75), which the
    error messages recommend.

    Args:
        img: Image array (rows on axis 0, columns on axis 1), or None if nothing was uploaded.
        prefix: Label used in the error messages, e.g. "Model" or "Garment".
        min_ratio: Smallest accepted width / height (inclusive).
        max_ratio: Largest accepted width / height (inclusive).

    Raises:
        gr.Error: if the image is missing, empty, or its aspect ratio is out of bounds.
    """
    if img is None:
        raise gr.Error(f"{prefix} image is missing. Please upload an image.")
    height, width = img.shape[:2]
    if height == 0 or width == 0:
        raise gr.Error(f"{prefix} image is empty.")
    aspect_ratio = width / height
    if aspect_ratio < min_ratio:
        raise gr.Error(f"{prefix} image W:H aspect ratio is too low. Use 2:3 or 3:4 for best results.")
    if aspect_ratio > max_ratio:
        raise gr.Error(f"{prefix} image W:H aspect ratio is too high. Use 2:3 or 3:4 for best results.")
# ----------------- CORE FUNCTION ----------------- #
# UI category label (radio button choice) -> value sent as "category" to the try-on API.
CATEGORY_API_MAPPING = {
    "Top": "tops",
    "Bottom": "bottoms",
    "Full-body": "one-pieces",
}
# Seconds to wait between two /status requests.
_POLL_INTERVAL_SECONDS = 3
# Overall budget (3 minutes) for one prediction, counted from just before the /run request.
_MAX_POLL_SECONDS = 180
# Statuses that mean "keep waiting"; any other status ends polling.
_IN_PROGRESS_STATUSES = ("starting", "in_queue", "processing")


def _client_ip(request):
    """Best-effort client IP: first x-forwarded-for entry, else the connection's peer host."""
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The header may hold a "client, proxy1, proxy2" chain; the first entry is the client.
        return forwarded.split(",")[0].strip()
    return request.client.host if request.client else None


def _encode_for_api(image):
    """Convert an RGB array from Gradio into the resized base64 JPEG data URI sent to the API."""
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # the OpenCV helpers work in BGR
    return encode_img_to_base64(resize_image(bgr_image))


async def _poll_prediction(client, pred_id, headers, deadline):
    """Poll ``/status/{pred_id}`` until the prediction leaves the in-progress states.

    Args:
        client: Open httpx.AsyncClient to reuse for the status requests.
        pred_id: Prediction ID returned by the /run endpoint.
        headers: Request headers (including authorization).
        deadline: time.monotonic() value after which polling gives up.

    Returns:
        The final status payload; its "status" is "completed" or a failure state.

    Raises:
        gr.Error: if the deadline passes or a status request fails.
    """
    while True:
        if time.monotonic() > deadline:
            raise gr.Error("Maximum polling time exceeded.")
        status_response = await client.get(
            f"{FASHN_API_URL}/status/{pred_id}", headers=headers, timeout=httpx.Timeout(10)
        )
        if status_response.is_error:
            raise gr.Error(f"Status polling failed: {status_response.text}")
        status_data = status_response.json()
        status = status_data.get("status")
        logger.info(f"Prediction status: {status}")
        if status not in _IN_PROGRESS_STATUSES:
            return status_data
        await asyncio.sleep(_POLL_INTERVAL_SECONDS)


async def get_tryon_result(model_image, garment_image, category, model_checkboxes, request: gr.Request):
    """Run one virtual try-on through the FASHN API and return the watermarked result.

    Args:
        model_image: RGB numpy image of the person (from the Gradio component), or None.
        garment_image: RGB numpy image of the garment, or None.
        category: UI category label; must be a key of CATEGORY_API_MAPPING.
        model_checkboxes: Selected "Additional Controls" labels (may be None).
        request: Request injected by Gradio; used to report the client IP on NSFW rejections.

    Returns:
        The result image as an RGB numpy array, or the path of the NSFW placeholder image
        when the API rejects the input as NSFW.

    Raises:
        gr.Error: on missing/invalid inputs, API errors, failed predictions, or polling timeout.
    """
    logger.info("Starting new try-on request...")
    client_ip = _client_ip(request) if request else None

    # Validate inputs before spending an API call; an empty image slot arrives as None.
    if model_image is None:
        raise gr.Error("Please upload a model image.")
    if garment_image is None:
        raise gr.Error("Please upload a garment image.")
    verify_aspect_ratio(model_image, "Model")
    # NOTE: the garment aspect-ratio check is currently disabled.

    if category not in CATEGORY_API_MAPPING:
        raise gr.Error(f"Unknown category: {category}")
    data = {
        "model_image": _encode_for_api(model_image),
        "garment_image": _encode_for_api(garment_image),
        "category": CATEGORY_API_MAPPING[category],
        **parse_checkboxes(model_checkboxes or []),
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {FASHN_API_KEY}"}

    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    start_time = time.monotonic()
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{FASHN_API_URL}/run", headers=headers, json=data, timeout=httpx.Timeout(300.0)
        )
        if response.is_error:
            raise gr.Error(f"API request failed: {response.text}")
        pred_id = response.json().get("id")
        if not pred_id:
            raise gr.Error(f"API response contained no prediction ID: {response.text}")
        logger.info(f"Prediction ID: {pred_id}")
        status_data = await _poll_prediction(client, pred_id, headers, start_time + _MAX_POLL_SECONDS)

    if status_data.get("status") != "completed":
        error = status_data.get("error")
        # `error` may be absent (None); `"NSFW" in None` would raise TypeError.
        if "NSFW" in str(error or ""):
            if request:
                logger.warning(f"NSFW attempt IP address: {client_ip}")
                gr.Warning(f"NSFW attempt IP address: {client_ip}")
            return NSFW
        raise gr.Error(f"Prediction failed: {error}")

    outputs = status_data.get("output") or []
    if not outputs:
        raise gr.Error("Prediction completed without an output image.")

    # opencv_load_image_from_http uses blocking `requests`; run it in a worker thread so the
    # event loop is not stalled for the duration of the download.
    loop = asyncio.get_running_loop()
    result_img = await loop.run_in_executor(None, opencv_load_image_from_http, outputs[0])
    # The downloaded image (BGR) and WATERMARK (BGRA, loaded by cv2.imread) share OpenCV's
    # channel order, so blend first and convert to RGB afterwards; blending after the
    # conversion would swap the watermark's red and blue channels.
    result_img = add_watermark(result_img, WATERMARK)
    return cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
# ----------------- GRADIO UI ----------------- #
def _read_html(filename: str) -> str:
    """Return the UTF-8 text of an HTML snippet stored next to this script.

    Resolved relative to this file (like ASSETS_DIR) rather than the current working
    directory, so the app also starts when launched from another directory.
    """
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


banner = _read_html("banner.html")
tips = _read_html("tips.html")
footer = _read_html("footer.html")
docs = _read_html("docs.html")
# Page styles passed to gr.Blocks: cap the size of images inside `.image-container`
# elements and set the app background color.
CUSTOM_CSS = """
.image-container img {
    max-width: 192px;
    max-height: 288px;
    margin: 0 auto;
    border-radius: 0px;
}
.gradio-container {background-color: #fafafa}
"""
def _example_images(subdir: str) -> list:
    """Return the example image paths in ``ASSETS_DIR/<subdir>``, sorted by filename.

    os.listdir() returns entries in arbitrary order, so sorting keeps the example gallery
    stable; hidden entries (names starting with ".") are skipped as they are not examples.
    """
    folder = os.path.join(ASSETS_DIR, subdir)
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder)) if not name.startswith(".")]


with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
    gr.HTML(banner)
    gr.HTML(tips)
    with gr.Row():
        with gr.Column():
            model_image = gr.Image(label="Model Image", type="numpy", format="png")
            # Optional flags; each selected label is forwarded to the API via parse_checkboxes.
            model_checkboxes = gr.CheckboxGroup(
                choices=["Remove Accessories", "Restore Hands", "Cover Feet"], label="Additional Controls", type="value"
            )
            example_model = gr.Examples(
                inputs=model_image,
                examples_per_page=10,
                examples=_example_images("models"),
            )
        with gr.Column():
            garment_image = gr.Image(label="Garment Image", type="numpy", format="png")
            category = gr.Radio(choices=["Top", "Bottom", "Full-body"], label="Select Category", value="Top")
            example_garment = gr.Examples(
                inputs=garment_image,
                examples_per_page=10,
                examples=_example_images("garments"),
            )
        with gr.Column():
            result_image = gr.Image(label="Try-on Result", format="png")
            run_button = gr.Button("Run")
    gr.HTML(docs)
    run_button.click(
        fn=get_tryon_result,
        inputs=[model_image, garment_image, category, model_checkboxes],
        outputs=[result_image],
    )
    gr.HTML(footer)
if __name__ == "__main__":
    # Log the IP address reported by ifconfig.me, for diagnostics only. A slow or
    # unreachable lookup service must not prevent the app from launching.
    try:
        ip = requests.get("http://ifconfig.me/ip", timeout=1).text.strip()
        logger.info(f"VM IP address: {ip}")
    except requests.RequestException as exc:
        logger.warning(f"Could not determine VM IP address: {exc}")
    demo.launch(share=False)