Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
from PIL import Image
|
|
|
|
| 3 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
| 4 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
| 5 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
|
@@ -15,7 +17,6 @@ from typing import List
|
|
| 15 |
import torch
|
| 16 |
import os
|
| 17 |
from transformers import AutoTokenizer
|
| 18 |
-
import spaces
|
| 19 |
import numpy as np
|
| 20 |
from utils_mask import get_mask_location
|
| 21 |
from torchvision import transforms
|
|
@@ -25,6 +26,7 @@ from preprocess.openpose.run_openpose import OpenPose
|
|
| 25 |
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
|
| 26 |
from torchvision.transforms.functional import to_pil_image
|
| 27 |
|
|
|
|
| 28 |
|
| 29 |
def pil_to_binary_mask(pil_image, threshold=0):
|
| 30 |
np_image = np.array(pil_image)
|
|
@@ -121,10 +123,15 @@ pipe = TryonPipeline.from_pretrained(
|
|
| 121 |
)
|
| 122 |
pipe.unet_encoder = UNet_Encoder
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
| 129 |
pipe.to(device)
|
| 130 |
pipe.unet_encoder.to(device)
|
|
@@ -150,7 +157,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_ste
|
|
| 150 |
if is_checked:
|
| 151 |
keypoints = openpose_model(human_img.resize((384,512)))
|
| 152 |
model_parse, _ = parsing_model(human_img.resize((384,512)))
|
| 153 |
-
mask, mask_gray = get_mask_location('hd',
|
| 154 |
mask = mask.resize((768,1024))
|
| 155 |
else:
|
| 156 |
mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
|
|
@@ -281,6 +288,7 @@ with image_blocks as demo:
|
|
| 281 |
with gr.Row(elem_id="prompt-container"):
|
| 282 |
with gr.Row():
|
| 283 |
prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
|
|
|
|
| 284 |
example = gr.Examples(
|
| 285 |
inputs=garm_img,
|
| 286 |
examples_per_page=8,
|
|
@@ -304,7 +312,7 @@ with image_blocks as demo:
|
|
| 304 |
|
| 305 |
|
| 306 |
|
| 307 |
-
try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
|
| 308 |
|
| 309 |
|
| 310 |
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('./')
|
| 3 |
from PIL import Image
|
| 4 |
+
import gradio as gr
|
| 5 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
| 6 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
| 7 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
|
|
|
| 17 |
import torch
|
| 18 |
import os
|
| 19 |
from transformers import AutoTokenizer
|
|
|
|
| 20 |
import numpy as np
|
| 21 |
from utils_mask import get_mask_location
|
| 22 |
from torchvision import transforms
|
|
|
|
| 26 |
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
|
| 27 |
from torchvision.transforms.functional import to_pil_image
|
| 28 |
|
| 29 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 30 |
|
| 31 |
def pil_to_binary_mask(pil_image, threshold=0):
|
| 32 |
np_image = np.array(pil_image)
|
|
|
|
| 123 |
)
|
| 124 |
pipe.unet_encoder = UNet_Encoder
|
| 125 |
|
| 126 |
+
def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed, category):
|
| 127 |
+
category = int(category)
|
| 128 |
+
if category==0:
|
| 129 |
+
category='upper_body'
|
| 130 |
|
| 131 |
+
elif category==1:
|
| 132 |
+
category='lower_body'
|
| 133 |
+
else:
|
| 134 |
+
category='dresses'
|
| 135 |
openpose_model.preprocessor.body_estimation.model.to(device)
|
| 136 |
pipe.to(device)
|
| 137 |
pipe.unet_encoder.to(device)
|
|
|
|
| 157 |
if is_checked:
|
| 158 |
keypoints = openpose_model(human_img.resize((384,512)))
|
| 159 |
model_parse, _ = parsing_model(human_img.resize((384,512)))
|
| 160 |
+
mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
|
| 161 |
mask = mask.resize((768,1024))
|
| 162 |
else:
|
| 163 |
mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
|
|
|
|
| 288 |
with gr.Row(elem_id="prompt-container"):
|
| 289 |
with gr.Row():
|
| 290 |
prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
|
| 291 |
+
category = gr.Textbox(placeholder="0 = upper body, 1 = lower body, 2 = full body", show_label=True)
|
| 292 |
example = gr.Examples(
|
| 293 |
inputs=garm_img,
|
| 294 |
examples_per_page=8,
|
|
|
|
| 312 |
|
| 313 |
|
| 314 |
|
| 315 |
+
try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed, category], outputs=[image_out,masked_img], api_name='tryon')
|
| 316 |
|
| 317 |
|
| 318 |
|