selfitcamera commited on
Commit
5d333f5
·
1 Parent(s): 61db459
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. app.py +115 -0
  3. requirements.txt +9 -0
  4. utils.py +113 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from mtcnn.mtcnn import MTCNN
3
+ from utils import *
4
+
5
+
6
+ face_detector = MTCNN()
7
+
8
+ # Description
9
+ title = r"""
10
+ <h1 align="center">IDM-VTON + Outfit Anyone in the Wild </h1>
11
+ """
12
+
13
+ description = r"""
14
+ This demo combines <b>IDM-VTON </b> and <b>Outfit Anyone in the Wild </b>
15
+ 1. Human pose detection and reconstruction using large human model from Outfit Anyone in the Wild.
16
+ 2. Use IDM-VTON for training-free try-on.
17
+ 3. Applying the refine network from Outfit Anyone in the Wild.<br>
18
+ """
19
+
20
+ css = """
21
+ .gradio-container {width: 85% !important}
22
+ """
23
+
24
+
25
+ def onClick(cloth_image, pose_image, category,
26
+ caption, request: gr.Request):
27
+ if pose_image is None:
28
+ yield None, f"no user image found !"
29
+ return None, "no user image found !"
30
+ elif cloth_image is None:
31
+ yield None, f"no cloth image found !"
32
+ return None, "no cloth image found !"
33
+ try:
34
+ faces = face_detector.detect_faces(pose_image[:,:,::-1])
35
+ if len(faces)==0:
36
+ print(client_ip, 'faces num is 0! ', flush=True)
37
+ yield None, "Fatal Error !!! No face detected in pose image !!! "
38
+ return None, "Fatal Error !!! No face detected in pose image !!! "
39
+ else:
40
+ x, y, w, h = faces[0]["box"]
41
+ H, W = pose_image.shape[:2]
42
+ max_face_ratio = 1/3.3
43
+ if w/W>max_face_ratio or h/H>max_face_ratio:
44
+ yield None, "Fatal Error !!! Headshot is not allowed in pose image!!!"
45
+ return None, "Fatal Error !!! Headshot is not allowed in pose image!!!"
46
+
47
+ uploads = upload_imgs(ApiUrl, UploadToken, cloth_image, pose_image)
48
+ if uploads is None:
49
+ yield None, "fail to upload"
50
+ return None, "fail to upload"
51
+
52
+ infId = publicFastSwap(ApiUrl, OpenId, ApiKey, uploads, category, caption)
53
+ if not infId:
54
+ yield None, "fail to public you task"
55
+ return None, "fail to public you task"
56
+
57
+ max_try = 30
58
+ wait_s = 3
59
+ yield None, "start to process, please wait..."
60
+ for i in range(max_try):
61
+ time.sleep(wait_s)
62
+ taskStatus = getTaskRes(ApiUrl, infId)
63
+ if taskStatus is None: continue
64
+
65
+ status = taskStatus['status']
66
+ if status in ['FAILED', 'CANCELLED', 'TIMED_OUT', ]:
67
+ yield None, f"task failed, query {i}, status {status}"
68
+ return None, f"task failed, query {i}, status {status}"
69
+ elif status in ['IN_QUEUE', 'IN_PROGRESS', 'IN_QUEUE', ]:
70
+ pass
71
+ yield None, f"task is on processing, query {i}, status {status}, please do not exit !!!"
72
+ elif status=='COMPLETED':
73
+ out = taskStatus['output']['job_results']['output1']
74
+ yield out, f"task is COMPLETED"
75
+ return out, f"{i} task COMPLETED"
76
+ yield None, "fail to query task.."
77
+ return None, "fail to query task.."
78
+
79
+
80
+ except Exception as e:
81
+ print(e)
82
+ return None, "fail to create task"
83
+
84
+
85
+ with gr.Blocks(css=css) as demo:
86
+ # description
87
+ gr.Markdown(title)
88
+ gr.Markdown(description)
89
+
90
+ with gr.Row():
91
+ with gr.Column():
92
+ with gr.Column():
93
+ cloth_image = gr.Image(value=None, type="numpy", label="cloth")
94
+ with gr.Column():
95
+ with gr.Column():
96
+ pose_image = gr.Image(value=None, type="numpy", label="user photo")
97
+ with gr.Column():
98
+ with gr.Column():
99
+ category = gr.Dropdown(value="upper_cloth", choices=["upper_cloth",
100
+ "lower_cloth", "full_body", "dresses"], interactive=True)
101
+ caption = gr.Textbox(value="", interactive=True, label='cloth caption')
102
+
103
+ info_text = gr.Textbox(value="", interactive=False, label='runtime information')
104
+ run_button = gr.Button(value="Run")
105
+ res_image = gr.Image(label="result image", value=None, type="filepath")
106
+
107
+ run_button.click(fn=onClick, inputs=[cloth_image, pose_image,
108
+ category, caption, ],
109
+ outputs=[res_image, info_text, ])
110
+
111
+ if __name__ == "__main__":
112
+
113
+ demo.queue(max_size=50)
114
+ demo.launch(server_name='0.0.0.0')
115
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ numpy
3
+ requests
4
+ gradio==3.41.2
5
+ gradio-client==0.5.0
6
+ mtcnn
7
+ tensorflow
8
+ func_timeout
9
+ httpx==0.24.1
utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import json
6
+ import random
7
+ import time
8
+ import requests
9
+ import func_timeout
10
+ import numpy as np
11
+ import gradio as gr
12
+
13
+
14
+ ApiUrl = os.environ['ApiUrl']
15
+ OpenId = os.environ['OpenId']
16
+ ApiKey = os.environ['ApiKey']
17
+ UploadToken = os.environ['UploadToken']
18
+
19
+
20
+ proj_dir = os.path.dirname(os.path.abspath(__file__))
21
+ data_dir = os.path.join(proj_dir, 'Datas')
22
+ # data_dir = "Datas"
23
+ tmpFolder = "tmp"
24
+ os.makedirs(tmpFolder, exist_ok=True)
25
+
26
+
27
+
28
+ def upload_imgs(apiUrl, UploadToken, cloth_image, pose_image):
29
+ folder = os.path.join(tmpFolder, str(random.randint(0, 100)))
30
+ os.makedirs(folder, exist_ok=True)
31
+ pose_path = os.path.join(folder, 'pose.jpg')
32
+ cloth_path = os.path.join(folder, 'cloth.jpg')
33
+ cv2.imwrite(pose_path, pose_image[:,:,::-1])
34
+ cv2.imwrite(cloth_path, cloth_image[:,:,::-1])
35
+
36
+ params = {'token':UploadToken,
37
+ 'input1':'pose.jpg',
38
+ 'input2':'cloth.jpg',
39
+ 'protocol':'https'}
40
+ session = requests.session()
41
+ ret = requests.post(f"{apiUrl}/upload", data=json.dumps(params))
42
+ if ret.status_code==200:
43
+ if 'upload1' in ret.json():
44
+ data = ret.json()
45
+ with open(cloth_path, 'rb') as file:
46
+ headers = {"Content-Type": 'image/jpeg'}
47
+ response = requests.put(data['upload2'], data=file, headers=headers)
48
+ if response.status_code == 200:
49
+ print("上传成功")
50
+ else:
51
+ print(f"上传失败,状态码: {response.status_code}, 响应内容: {response.text}")
52
+ return
53
+ with open(pose_path, 'rb') as file:
54
+ response = requests.put(data['upload1'], data=file, headers=headers)
55
+ if response.status_code == 200:
56
+ print("上传成功")
57
+ else:
58
+ print(f"上传失败,状态码: {response.status_code}, 响应内容: {response.text}")
59
+ return
60
+ if os.path.exists(pose_path): os.remove(pose_path)
61
+ if os.path.exists(cloth_path): os.remove(cloth_path)
62
+ return {'pose':data['upload1'], 'cloth':data['upload2']}
63
+
64
+ def publicFastSwap(apiUrl, openId, apiKey, uploads, category, caption):
65
+ if category=="upper_cloth":
66
+ category = 1
67
+ elif category=="lower_cloth":
68
+ category = 2
69
+ elif category=="dresses":
70
+ category = 3
71
+ elif category=="full_body":
72
+ category = 4
73
+ params = {'openId':OpenId, 'apiKey':ApiKey,
74
+ 'task_type':"10", 'image':str(uploads['pose']),
75
+ 'mask':str(uploads['cloth']),
76
+ 'param1':str(category), 'param2':str(caption),
77
+ 'param3':'', 'param4':'', 'param5':'', }
78
+ session = requests.session()
79
+ ret = requests.post(f"{ApiUrl}/public", data=json.dumps(params))
80
+ if ret.status_code==200:
81
+ if 'id' in ret.json():
82
+ print('public task successfully!')
83
+ return ret.json()['id']
84
+
85
+ def getTaskRes(apiUrl, taskId):
86
+ params = {'id':taskId, 'task_type':"10"}
87
+ session = requests.session()
88
+ ret = requests.post(f"{apiUrl}/status", data=json.dumps(params))
89
+ if ret.status_code==200:
90
+ if 'status' in ret.json():
91
+ return ret.json()
92
+ else:
93
+ print(ret.json(), ret.status_code, 'call status failed')
94
+ return None
95
+
96
+ @func_timeout.func_set_timeout(10)
97
+ def check_func(ip):
98
+ session = requests.session()
99
+ ret = requests.get(f"https://webapi-pc.meitu.com/common/ip_location?ip={ip}")
100
+ for k in ret.json()['data']:
101
+ nat = ret.json()['data'][k]['nation']
102
+ if nat.lower() in Regions.lower():
103
+ print(nat, 'invalid')
104
+ return False
105
+ else:
106
+ print(nat, 'valid')
107
+ return True
108
+ def check_warp(ip):
109
+ try:
110
+ return check_func(ip)
111
+ except Exception as e:
112
+ print(e)
113
+ return True