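"""Virtual Try-On demo.

Grounding DINO locates the garment on the model photo, SAM segments it, and an
SDXL inpainting pipeline with an IP-Adapter redraws the masked region to match
the reference cloth image. Served as a Gradio app (Hugging Face Spaces, ZeroGPU).
"""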
import gc

import gradio as gr
import numpy as np
import spaces  # required for the @spaces.GPU decorator on ZeroGPU Spaces
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Zero-shot detector (Grounding DINO) and segmenter (SAM) for locating the garment
GDINO_MODEL_NAME = "IDEA-Research/grounding-dino-tiny"
SAM_MODEL_NAME = "facebook/sam-vit-base"

GDINO = pipeline(model=GDINO_MODEL_NAME, task="zero-shot-object-detection", device=DEVICE)
SAM = AutoModelForMaskGeneration.from_pretrained(SAM_MODEL_NAME).to(DEVICE)
SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_MODEL_NAME)

# SDXL inpainting pipeline with an IP-Adapter conditioned on the cloth image
SD_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
SD_PIPELINE = AutoPipelineForInpainting.from_pretrained(SD_MODEL, torch_dtype=torch.float16).to(DEVICE)
IP_ADAPTER = "h94/IP-Adapter"
SUB_FOLDER = "sdxl_models"
IP_WEIGHT_NAME = "ip-adapter_sdxl.bin"
SD_PIPELINE.load_ip_adapter(IP_ADAPTER, subfolder=SUB_FOLDER, weight_name=IP_WEIGHT_NAME)
IP_SCALE = 0.6  # balance between the text prompt and the cloth reference image
SD_PIPELINE.set_ip_adapter_scale(IP_SCALE)

GEN_STEPS = 100


def refine_masks(masks: torch.BoolTensor) -> np.ndarray:
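    """Collapse SAM's per-box mask proposals into one soft mask per box.

    Input: (num_boxes, num_proposals, H, W) bool tensor from post_process_masks;
    output: float array of shape (num_boxes, H, W) with values in [0, 1].
    """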
    masks = masks.permute(0, 2, 3, 1)    # (boxes, proposals, H, W) -> (boxes, H, W, proposals)
    masks = masks.float().mean(axis=-1)  # average SAM's proposals into one soft mask per box
    return masks.cpu().numpy()


def get_boxes(detections: list) -> list:
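    """Convert detection dicts into [[xmin, ymin, xmax, ymax], ...] box lists,
    wrapped in an outer list because the SAM processor batches boxes per image."""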
    boxes = []
    for det in detections:
        boxes.append([det['box']['xmin'], det['box']['ymin'],
                      det['box']['xmax'], det['box']['ymax']])
    return [boxes]


def get_mask(img: Image.Image, prompt: str, d_model: pipeline, s_model: AutoModelForMaskGeneration,
             s_processor: AutoProcessor, device: str, threshold: float = 0.3) -> np.ndarray:
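    """Detect 'face' and the prompted garment with Grounding DINO, then segment
    the detected boxes with SAM. Returns one soft mask per detection; the caller
    assumes the mask at index 1 belongs to the garment (see generate_result).
    """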
    
    # Grounding DINO expects each label to end with a period
    labels = [label if label.endswith(".") else label + "." for label in ['face', prompt]]
    dets = d_model(img, candidate_labels=labels, threshold=threshold)

    boxes = get_boxes(dets)
    inputs = s_processor(images=img, input_boxes=boxes, return_tensors="pt").to(device)
    outputs = s_model(**inputs)

    masks = s_processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    return refine_masks(masks)


def generate_result(model_img: Image.Image, cloth_img: Image.Image, masks: np.ndarray,
                    prompt: str, sd_pipeline: AutoPipelineForInpainting, n_steps: int = 100) -> Image.Image:
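    """Inpaint the garment region of model_img with SDXL, using cloth_img as the
    IP-Adapter image reference and prompt as the text description."""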
    
    width, height = model_img.size

    # masks[1] is assumed to be the garment mask (masks[0] covers the 'face' detection)
    cloth_mask = masks[1]
    generator = torch.Generator(device="cpu").manual_seed(4)  # fixed seed for reproducible outputs
    images = sd_pipeline(
        prompt=prompt,
        image=model_img,
        mask_image=cloth_mask,
        ip_adapter_image=cloth_img,
        generator=generator,
        num_inference_steps=n_steps,
    ).images

    # the pipeline may generate at a different resolution; resize back to the input size
    return images[0].resize((width, height))


@spaces.GPU
def run(model_img: Image.Image, cloth_img: Image.Image, cloth_class: str, cloth_description: str) -> Image.Image:
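    """End-to-end try-on: segment the garment named by cloth_class on model_img,
    then inpaint it to match cloth_img, guided by cloth_description."""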
    masks = get_mask(model_img, cloth_class, GDINO, SAM, SAM_PROCESSOR, DEVICE)
    result = generate_result(model_img, cloth_img, masks, cloth_description, SD_PIPELINE, GEN_STEPS)
    gc.collect()               # release Python-side references
    torch.cuda.empty_cache()   # free cached GPU memory between requests
    return result


gr.Interface(
    run,
    title='Virtual Try-On',
    inputs=[
        gr.Image(sources='upload', label='Model image', type='pil'),
        gr.Image(sources='upload', label='Cloth image', type='pil'),
        gr.Textbox(label='Cloth class'),
        gr.Textbox(label='Cloth description'),
    ],
    outputs=[
        gr.Image(),
    ],
    examples=[
        ["./examples/models/girl1.jpg", "./examples/clothes/t_short.jpg", "shirt", "black shirt"],
    ],
).launch(debug=True, share=True)