File size: 6,197 Bytes
0f7f5eb 6196c52 0f7f5eb 6196c52 0f7f5eb 37ae657 3286ec0 37ae657 3286ec0 6196c52 3286ec0 6196c52 3286ec0 6196c52 3286ec0 6196c52 3286ec0 6196c52 3286ec0 6196c52 3286ec0 0f7f5eb 3286ec0 0f7f5eb 3286ec0 4d2e566 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import os
import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from gradio_client import Client, handle_file
# Example galleries shown under the "Human" and "Garment" inputs.
# Listings are sorted because os.listdir() returns entries in arbitrary order.
# Hidden entries (e.g. ".DS_Store") and sub-directories are skipped so that
# only actual files are offered to gr.Examples.
example_path = os.path.join(os.path.dirname(__file__), 'example')


def _list_example_files(subdir):
    """Return the sorted names of visible regular files in ``example/<subdir>``."""
    folder = os.path.join(example_path, subdir)
    return sorted(
        name
        for name in os.listdir(folder)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
    )


garm_list = _list_example_files("cloth")
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
human_list = _list_example_files("human")
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
# MediaPipe Pose handles shared by detect_pose() and align_clothing().
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose
mp_pose_landmark = mp_pose.PoseLandmark  # landmark enum: LEFT_SHOULDER, RIGHT_HIP, ...
# static_image_mode=True: per the MediaPipe Pose docs, every call runs full
# detection on the image instead of tracking across video frames.
pose = mp_pose.Pose(static_image_mode=True)
def detect_pose(image, *, return_keypoints=False):
    """Annotate ``image`` with the MediaPipe pose skeleton and torso keypoints.

    The image is converted with ``cv2.COLOR_BGR2RGB`` before detection, so it
    is expected in OpenCV BGR channel order. All drawing happens *in place*
    on ``image``; the same array is returned for convenience.

    Args:
        image: Image array to analyse and annotate (modified in place).
        return_keypoints: When True, also return the shoulder/hip pixel
            coordinates. Defaults to False, which keeps the single-value
            return used by existing callers.

    Returns:
        ``image``, or ``(image, keypoints)`` when ``return_keypoints`` is True.
        ``keypoints`` maps ``'left_shoulder'``, ``'right_shoulder'``,
        ``'left_hip'`` and ``'right_hip'`` to integer ``(x, y)`` pixel
        coordinates. It is empty when no pose is detected.
    """
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    result = pose.process(image_rgb)
    keypoints = {}
    if result.pose_landmarks:
        # Full skeleton overlay.
        mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
        # Landmarks are normalised to [0, 1]; scale them to pixel coordinates.
        height, width = image.shape[:2]
        landmark_indices = {
            'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
            'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
            'left_hip': mp_pose_landmark.LEFT_HIP,
            'right_hip': mp_pose_landmark.RIGHT_HIP
        }
        for name, index in landmark_indices.items():
            lm = result.pose_landmarks.landmark[index]
            x, y = int(lm.x * width), int(lm.y * height)
            keypoints[name] = (x, y)
            # Debug overlay: filled green dot plus the landmark name.
            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    if return_keypoints:
        return image, keypoints
    return image
def align_clothing(body_img, clothing_img):
    """Warp a garment image onto the torso of the person in ``body_img``.

    The garment's four corners are mapped onto the quadrilateral formed by
    the detected shoulders (top edge) and hips (bottom edge). The warped
    garment is then alpha-composited over a copy of the body image.

    Args:
        body_img: Image of a person. It is converted with
            ``cv2.COLOR_BGR2RGB`` before pose detection, so it is expected in
            OpenCV BGR order. Assumed uint8 (values are treated as 0-255).
        clothing_img: Garment image (BGR, BGRA or single-channel grayscale).
            A 4th channel is used as the alpha mask. Without one, the garment
            is pasted opaquely over its warped footprint.

    Returns:
        A new image the same size as ``body_img``. If no pose is detected it
        is an unmodified copy of ``body_img``.
    """
    image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
    result = pose.process(image_rgb)
    output = body_img.copy()
    if not result.pose_landmarks:
        return output

    h, w = output.shape[:2]
    landmarks = result.pose_landmarks.landmark

    def get_point(landmark_id):
        # Landmarks are normalised to [0, 1]; scale to pixel coordinates.
        lm = landmarks[landmark_id]
        return lm.x * w, lm.y * h

    # MediaPipe's LEFT/RIGHT are the subject's own sides, so for a
    # front-facing person LEFT_SHOULDER lies on the image's right. Order each
    # pair by x so the garment's left edge lands on the image-left landmark
    # rather than being mirrored.
    top_left, top_right = sorted((get_point(mp_pose_landmark.LEFT_SHOULDER),
                                  get_point(mp_pose_landmark.RIGHT_SHOULDER)))
    bottom_left, bottom_right = sorted((get_point(mp_pose_landmark.LEFT_HIP),
                                        get_point(mp_pose_landmark.RIGHT_HIP)))

    # Destination quad (torso region), clockwise from top-left.
    dst_pts = np.array([top_left, top_right, bottom_right, bottom_left], dtype=np.float32)

    # Source quad: the clothing image corners, in the same order.
    src_h, src_w = clothing_img.shape[:2]
    src_pts = np.array([
        [0, 0],
        [src_w, 0],
        [src_w, src_h],
        [0, src_h]
    ], dtype=np.float32)

    matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
    # BORDER_CONSTANT (zeros) instead of BORDER_TRANSPARENT: with no dst
    # array supplied, TRANSPARENT leaves pixels outside the garment
    # uninitialised.
    warped_clothing = cv2.warpPerspective(clothing_img, matrix, (w, h),
                                          borderMode=cv2.BORDER_CONSTANT, borderValue=0)

    if clothing_img.ndim == 3 and clothing_img.shape[2] == 4:
        # Use the garment's own alpha channel.
        alpha = warped_clothing[:, :, 3].astype(np.float32) / 255.0
        garment = warped_clothing[:, :, :3]
    else:
        # No alpha channel: the garment is opaque wherever it was warped to.
        # The old addWeighted(0.8, 0.5) blend darkened the whole body image
        # and over-saturated the garment area.
        footprint = cv2.warpPerspective(np.full((src_h, src_w), 255, dtype=np.uint8),
                                        matrix, (w, h),
                                        borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        alpha = footprint.astype(np.float32) / 255.0
        garment = (warped_clothing if warped_clothing.ndim == 3
                   else cv2.cvtColor(warped_clothing, cv2.COLOR_GRAY2BGR))

    # Vectorised alpha composite over the colour channels, rounded back to
    # the body image's dtype.
    alpha = alpha[:, :, np.newaxis]
    base = output[:, :, :3].astype(np.float32)
    blended = (1.0 - alpha) * base + alpha * garment.astype(np.float32)
    output[:, :, :3] = np.clip(np.rint(blended), 0, 255).astype(output.dtype)
    return output
def process_image(human_img_path, garm_img_path):
    """Run virtual try-on via the hosted Leffa Space and return the result.

    Args:
        human_img_path: Filepath of the person image (the "Human" input).
        garm_img_path: Filepath of the garment image (the "Garment" input).

    Returns:
        The first element returned by the ``/leffa_predict_vt`` endpoint,
        opened as a PIL image.

    Raises:
        gr.Error: If either input is missing. The UI then shows a readable
            message instead of the call failing inside ``handle_file``.
    """
    # Gradio passes None for an empty image input (e.g. Try-on clicked
    # before uploading).
    if not human_img_path or not garm_img_path:
        raise gr.Error("Please upload both a person image and a garment image.")

    client = Client("franciszzj/Leffa")
    result = client.predict(
        src_image_path=handle_file(human_img_path),
        ref_image_path=handle_file(garm_img_path),
        ref_acceleration=False,
        step=30,
        scale=2.5,
        seed=42,
        vt_model_type="viton_hd",
        vt_garment_type="upper_body",
        vt_repaint=False,
        api_name="/leffa_predict_vt"
    )
    print(result)
    generated_image_path = result[0]
    # f-string: works even if the entry is not a str, and adds a separator.
    print(f"generated_image_path: {generated_image_path}")
    # Image.open is lazy. Copy inside a `with` so the file handle is closed
    # here rather than left open until garbage collection.
    with Image.open(generated_image_path) as generated_image:
        return generated_image.copy()
# Create the main interface: a "Virtual Try-On" tab wired to process_image,
# plus a "Help" tab with static text.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Virtual Try-On"):
            gr.HTML("<center><h1>Virtual Try-On</h1></center>")
            gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
            with gr.Row():
                with gr.Column():
                    human_img = gr.Image(type="filepath", label='Human', interactive=True)
                    gr.Examples(
                        inputs=human_img,
                        examples_per_page=10,
                        examples=human_list_path
                    )
                with gr.Column():
                    garm_img = gr.Image(label="Garment", type="filepath", interactive=True)
                    gr.Examples(
                        inputs=garm_img,
                        examples_per_page=8,
                        examples=garm_list_path
                    )
                with gr.Column():
                    image_out = gr.Image(label="Processed image", type="pil")
            with gr.Row():
                try_button = gr.Button(value="Try-on", variant='primary')
            # Linking the button to the processing function
            try_button.click(fn=process_image, inputs=[human_img, garm_img], outputs=image_out)
        with gr.TabItem("Help"):
            gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1>Thank You!</h1>
<p style="font-size: 18px; margin: 20px 0;">
Thank you for using our Virtual Try-On application. We appreciate your support and hope you enjoy trying on different outfits virtually!
</p>
<p style="font-size: 16px; margin: 20px 0;">
If you have any questions or need assistance, please don't hesitate to reach out to our support team.
</p>
</div>
""")

# Start the server only when run as a script (`python app.py`), so importing
# this module does not block on launch().
if __name__ == "__main__":
    demo.launch(show_error=True)
|