WeShop's picture
oauth
eb39cff
raw
history blame
15.8 kB
import gradio as gr
import hmac
import hashlib
import time
import os
import requests
from io import BytesIO
from PIL import Image
import uuid
import base64
# Directory holding the static assets shipped with the app (examples, title HTML),
# resolved relative to this file rather than the current working directory.
example_path = os.path.join(os.path.dirname(__file__), 'assets')
# Example inputs for the editor gallery. os.listdir() returns entries in arbitrary
# order, so sort them for a stable gallery, and skip dotfiles (e.g. .DS_Store,
# .gitkeep) that would otherwise be handed to gr.Examples as images.
clothing_list = sorted(
    name for name in os.listdir(os.path.join(example_path, "clothing"))
    if not name.startswith('.')
)
clothing_list_path = [os.path.join(example_path, "clothing", clothing) for clothing in clothing_list]
# Backend endpoints and credentials are read from environment variables.
# Each value is None when its variable is unset; nothing is validated here.
base_url = os.getenv('base_url')
upload_image_url = os.getenv('upload_image_url')
save_mask_url = os.getenv('save_mask_url')
create_save_task_url = os.getenv('create_save_task_url')
execute_task_url = os.getenv('execute_task_url')
query_task_url = os.getenv('query_task_url')
secret_key = os.getenv('secret_key')
agent_version = os.getenv('agent_version')
agent_name = os.getenv('agent_name')
app_id = os.getenv('app_id')
def parse_response(response, state='default'):
    """Decode a backend HTTP response into a ``(data, msg)`` pair.

    Args:
        response: object exposing ``status_code`` and ``json()`` (a
            ``requests.Response`` in this file).
        state: ``'default'`` returns the payload's ``"data"`` field;
            ``'saveMask'`` returns ``True`` only when the payload has a truthy
            ``"success"`` and ``"code" == "1001"``, otherwise ``False``.

    Returns:
        ``(data, msg)``: ``data`` is falsy on failure and ``msg`` then explains
        why; ``msg`` is ``''`` on success.
    """
    data = {}
    msg = ''
    if response.status_code != 200:
        return data, 'request error.'
    try:
        datas = response.json()
        if datas:
            data = datas.get("data")
            if state == 'default':
                if not data:
                    msg = datas.get("msg") or "Field error."
            elif state == 'saveMask':
                success = datas.get("success")
                code = datas.get("code")
                if success and code == "1001":
                    data = True
                else:
                    # Must not fall through with the payload's "data" field:
                    # a failed save that still carries data would look like
                    # success to callers, and the message would be empty.
                    data = False
                    msg = datas.get("msg") or "Save mask failed."
        else:
            msg = "The parsing result is empty."
    except Exception as e:
        # Best effort: malformed JSON or an unexpected payload shape (e.g. a
        # list without .get) is reported as a message instead of raising.
        msg = f"parse error: {repr(e)}."
    return data, msg
def generate_signature(key, did, timestamp):
    """Return the hex HMAC-SHA256 of ``"<did>:<timestamp><app_id>"`` keyed by *key*.

    ``app_id`` is the module-level value read from the environment.
    """
    message = "{}:{}{}".format(did, timestamp, app_id).encode()
    return hmac.new(key.encode(), msg=message, digestmod=hashlib.sha256).hexdigest()
def url_to_image(url, ip):
    """Download *url* and decode the body as a PIL image.

    Args:
        url: image URL returned by the backend.
        ip: client IP forwarded as ``X-Forwarded-For``.

    Returns:
        The decoded image, or None when the request fails, the server does not
        answer 200, or the body cannot be decoded as an image.
    """
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        'X-Forwarded-For': ip
    }
    try:
        response = requests.get(url, headers=headers, timeout=30)
    except (requests.RequestException, ValueError):
        # Network/URL errors (RequestException) and header-encoding errors
        # (ValueError) are best-effort failures; anything else should surface.
        return None
    if response.status_code != 200:
        return None
    try:
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so a corrupt or non-image body fails
        # here (UnidentifiedImageError is an OSError) rather than later in Gradio.
        img.load()
    except OSError:
        return None
    return img
def start_task(task_id, did, ip, *, timeout=60):
    """Ask the backend to execute a previously created task.

    Args:
        task_id: id returned by create_task.
        did: caller device id, sent in the signed headers.
        ip: client IP forwarded as ``X-Forwarded-For``.
        timeout: passed to ``requests.post`` (connect/read timeout, seconds).

    Returns:
        ``(data, msg)`` from parse_response, or ``({}, 'request error.')`` when
        the HTTP request itself fails (connection error, timeout, ...).
    """
    timestamp = str(int(time.time()))
    signature = generate_signature(key=secret_key, did=did, timestamp=timestamp)
    headers = {
        'Did': did,
        'X-Timestamp': timestamp,
        'X-Signature': signature,
        'X-Forwarded-For': ip,
        'X-AppId': app_id,
    }
    payload = {
        "agentVersion": agent_version,
        "agentName": agent_name,
        "taskId": task_id,
        "runFreeAsFallback": False,
    }
    try:
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.post(base_url + execute_task_url, json=payload,
                                 headers=headers, timeout=timeout)
    except requests.RequestException:
        # The exception text is not shown to users: it can contain the
        # private backend URL.
        return {}, 'request error.'
    return parse_response(response)
def create_task(image_url, did, ip, *, timeout=60):
    """Create a backend task for an already uploaded image.

    Args:
        image_url: URL returned by upload_image (its ``"image"`` field).
        did: caller device id, sent in the signed headers.
        ip: client IP forwarded as ``X-Forwarded-For``.
        timeout: passed to ``requests.post`` (connect/read timeout, seconds).

    Returns:
        ``(data, msg)`` from parse_response, or ``({}, 'request error.')`` when
        the HTTP request itself fails (connection error, timeout, ...).
    """
    timestamp = str(int(time.time()))
    signature = generate_signature(key=secret_key, did=did, timestamp=timestamp)
    headers = {
        'Did': did,
        'X-Timestamp': timestamp,
        'X-Signature': signature,
        'X-Forwarded-For': ip,
        'X-AppId': app_id,
    }
    payload = {
        "agentVersion": agent_version,
        "agentName": agent_name,
        "image": image_url,
    }
    try:
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.post(base_url + create_save_task_url, json=payload,
                                 headers=headers, timeout=timeout)
    except requests.RequestException:
        # The exception text is not shown to users: it can contain the
        # private backend URL.
        return {}, 'request error.'
    return parse_response(response)
def query_task(task_id, execution_id, did, ip, *, timeout=60):
    """Fetch the current state of a task execution (polled by generate_image).

    Args:
        task_id: id returned by create_task.
        execution_id: id returned by start_task.
        did: caller device id, sent in the signed headers.
        ip: client IP forwarded as ``X-Forwarded-For``.
        timeout: passed to ``requests.post`` (connect/read timeout, seconds).

    Returns:
        ``(data, msg)`` from parse_response, or ``({}, 'request error.')`` when
        the HTTP request itself fails (connection error, timeout, ...).
    """
    timestamp = str(int(time.time()))
    signature = generate_signature(key=secret_key, did=did, timestamp=timestamp)
    headers = {
        'Did': did,
        'X-Timestamp': timestamp,
        'X-Signature': signature,
        'X-Forwarded-For': ip,
        'X-AppId': app_id,
    }
    payload = {
        "agentVersion": agent_version,
        "agentName": agent_name,
        "taskId": task_id,
        "executionId": execution_id,
    }
    try:
        # A hung request would otherwise stall the polling loop past its own
        # overall deadline.
        response = requests.post(base_url + query_task_url, json=payload,
                                 headers=headers, timeout=timeout)
    except requests.RequestException:
        # The exception text is not shown to users: it can contain the
        # private backend URL.
        return {}, 'request error.'
    return parse_response(response)
def upload_image(image, did, ip, *, timeout=120):
    """Upload a PIL image to the backend as a multipart file.

    The image is re-encoded in its original format (PNG when unknown).

    Args:
        image: PIL image to upload, or None.
        did: caller device id, sent in the signed headers.
        ip: client IP forwarded as ``X-Forwarded-For``.
        timeout: passed to ``requests.post`` (connect/read timeout, seconds).

    Returns:
        ``(data, msg)`` from parse_response; ``(None, msg)`` when *image* is
        None; ``({}, 'request error.')`` when the HTTP request itself fails.
    """
    if image is None:
        # Callers unpack two values, so never return a bare None.
        return None, "Please upload the main image before generating."
    image_format = image.format if image.format else "PNG"
    mime_type = f"image/{image_format.lower()}"
    timestamp = str(int(time.time()))
    signature = generate_signature(key=secret_key, did=did, timestamp=timestamp)
    with BytesIO() as m_img:
        image.save(m_img, format=image_format)
        m_img.seek(0)
        files = {'image': (f"main_image.{image_format.lower()}", m_img, mime_type)}
        headers = {
            'Did': did,
            'X-Timestamp': timestamp,
            'X-Signature': signature,
            'X-Forwarded-For': ip,
            'X-AppId': app_id,
        }
        try:
            # Without a timeout a stalled backend would block this worker forever.
            response = requests.post(base_url + upload_image_url, files=files,
                                     headers=headers, timeout=timeout)
        except requests.RequestException:
            # The exception text is not shown to users: it can contain the
            # private backend URL.
            return {}, 'request error.'
    return parse_response(response)
def mask_image_save(task_id, mask, did, ip, *, timeout=60):
    """Attach a brush mask to a created task.

    Args:
        task_id: id returned by create_task.
        mask: the mask encoded as a PNG data URI (see image_to_base64).
        did: caller device id, sent in the signed headers.
        ip: client IP forwarded as ``X-Forwarded-For``.
        timeout: passed to ``requests.post`` (connect/read timeout, seconds).

    Returns:
        ``(data, msg)`` from parse_response in ``'saveMask'`` mode (``data`` is
        True on success), or ``({}, 'request error.')`` when the HTTP request
        itself fails.
    """
    timestamp = str(int(time.time()))
    signature = generate_signature(key=secret_key, did=did, timestamp=timestamp)
    headers = {
        'Did': did,
        'X-Timestamp': timestamp,
        'X-Signature': signature,
        'X-Forwarded-For': ip,
        'X-AppId': app_id,
    }
    payload = {
        "taskId": task_id,
        "mask": mask,
        "needGreyscale": True,
    }
    try:
        # Without a timeout a stalled backend would block this worker forever.
        response = requests.post(base_url + save_mask_url, json=payload,
                                 headers=headers, timeout=timeout)
    except requests.RequestException:
        # The exception text is not shown to users: it can contain the
        # private backend URL.
        return {}, 'request error.'
    return parse_response(response, state='saveMask')
def load_description(file_path):
    """Read the UTF-8 text file at *file_path* and return its whole content."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
def extract_and_binarize_alpha_channel(image, threshold=10):
    """Build a binary 'L' mask from a PIL image's alpha channel.

    Args:
        image: PIL image; only RGBA, LA, or P-with-transparency images have
            usable alpha.
        threshold: alpha values >= threshold become 255, all others 0.

    Returns:
        A new mode-'L' image of the same size, or None when *image* carries no
        alpha information.
    """
    if image.mode == 'P' and 'transparency' in image.info:
        # A palette image has a single band of palette indices, so taking its
        # last band would binarize indices, not opacity. Expand it to RGBA so
        # the transparency info becomes a real alpha channel.
        image = image.convert('RGBA')
    elif image.mode not in ('RGBA', 'LA'):
        return None
    alpha = image.getchannel('A')
    # point() applies the threshold via a 256-entry lookup table in C,
    # instead of a Python loop over every pixel.
    return alpha.point(lambda value: 255 if value >= threshold else 0)
def image_to_base64(image):
    """Serialize *image* as PNG and return it as a ``data:image/png;base64,`` URI."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode('utf-8')
def generate_image(edit_image_infos, did, request: gr.Request):
    """Run the whole retouching pipeline for one ImageEditor submission.

    Steps: validate the editor value and brush mask, upload the main image,
    create a task, save the mask, start the task, then poll every 2 seconds
    (for at most one hour) until it reports a result.

    Args:
        edit_image_infos: ImageEditor value; expected to be a dict whose
            'background' is the uploaded PIL image and whose first 'layers'
            entry holds the brush strokes.
        did: per-session device id kept in gr.State; generated when empty.
        request: injected by Gradio; used to find the client IP.

    Returns:
        ``(result, did)``: *result* is the generated PIL image, or the value of
        ``gr.Warning(...)`` after a warning has been shown to the user.
    """
    if not did:
        did = str(uuid.uuid4())
    if edit_image_infos is None or not isinstance(edit_image_infos, dict):
        m = "Please upload the main image before generating."
        return gr.Warning(m), did

    main_image = None
    mask_image = None
    background = edit_image_infos.get('background')
    if isinstance(background, Image.Image):
        main_image = background
    layers = edit_image_infos.get('layers')
    if layers and isinstance(layers, list) and isinstance(layers[0], Image.Image):
        mask_image = extract_and_binarize_alpha_channel(layers[0])
    if main_image is None or mask_image is None:
        m = "Unable to parse image data."
        return gr.Warning(m), did
    # getbbox() is None exactly when every mask pixel is 0, i.e. nothing was
    # painted; same test as scanning all pixels, but done in C.
    if mask_image.getbbox() is None:
        m = "Please use the brush tool to mark the areas on the image that require detailed hand retouching."
        return gr.Warning(m), did

    client_ip = request.client.host
    x_forwarded_for = request.headers.get('x-forwarded-for')
    if x_forwarded_for:
        client_ip = x_forwarded_for

    upload_image_data, upload_image_msg = upload_image(image=main_image, did=did, ip=client_ip)
    if not upload_image_data:
        return gr.Warning(upload_image_msg), did
    image_url = upload_image_data.get("image")
    if not image_url:
        m = 'Upload image failed.'
        return gr.Warning(m), did

    create_task_data, create_task_msg = create_task(image_url=image_url, did=did, ip=client_ip)
    if not create_task_data:
        return gr.Warning(create_task_msg), did
    task_id = create_task_data.get("taskId")
    show_image = create_task_data.get("showImage")
    if not task_id or not show_image:
        m = 'Create task failed.'
        return gr.Warning(m), did

    mask_image_save_data, mask_image_save_msg = mask_image_save(
        task_id=task_id,
        mask=image_to_base64(mask_image),
        did=did,
        ip=client_ip
    )
    if not mask_image_save_data:
        return gr.Warning(mask_image_save_msg), did

    start_task_data, start_task_msg = start_task(task_id=task_id, did=did, ip=client_ip)
    if not start_task_data:
        return gr.Warning(start_task_msg), did
    execution_id = start_task_data.get("executionId")
    if not execution_id:
        m = "The task failed to start."
        return gr.Warning(m), did

    # monotonic() so a wall-clock adjustment cannot shorten or extend the deadline.
    start_time = time.monotonic()
    while True:
        m = "Query task failed."
        query_task_data, query_task_msg = query_task(
            task_id=task_id,
            execution_id=execution_id,
            did=did,
            ip=client_ip
        )
        if not query_task_data:
            return gr.Warning(query_task_msg), did
        executions = query_task_data.get("executions")
        if not executions:
            return gr.Warning(m), did
        results = executions[0].get("result")
        if not results:
            return gr.Warning(m), did
        status = results[0].get("status")
        if status == "Failed":
            m = "The task failed. Please try again with a different image or mask."
            return gr.Warning(m), did
        if status in ("Success", "Blocked"):
            img = results[0].get("image")
            if img and str(img).strip() != "":
                result_image = url_to_image(img, ip=client_ip)
                if result_image is None:
                    # Otherwise the user would get an empty result with no feedback.
                    m = "Failed to download the result image."
                    return gr.Warning(m), did
                return result_image, did
        if time.monotonic() - start_time > 3600:
            m = 'Query task timeout.'
            return gr.Warning(m), did
        time.sleep(2)
def process_show_case_image(image1, image2, result_image):
    """Turn a show-case example row into the three output values.

    Returns the source image, an ImageEditor-style dict (source as background
    and composite, mask as the single layer), and the result image unchanged.
    """
    editor_value = dict(background=image1, layers=[image2], composite=image1)
    return image1, editor_value, result_image
# Custom stylesheet passed to gr.Blocks. Elements with the "middleware" class are
# hidden helper inputs; #example-res-images lays each show-case row out as two
# grid columns, stacking the source and mask cells in the first one.
css = """
.image-container img {
    max-height: 500px;
    width: auto;
}
#example-images img {
    border-radius: 10px
}
#example-images .gallery-item {
    border: none;
}
#example-images .container {
    border: none;
}
.hide-buttons .source-selection {
    display: none;
}
#example-images .gallery {
    display: flex;
    flex-wrap: wrap;
}
#example-images .gallery-item .container{
    width: 100%;
    max-width: 100%;
    max-height: 100% !important;
    height: 100% !important;
}
#example-images .gallery-item {
    flex: 0 0 30%;
    max-width: 30%;
    box-sizing: border-box;
    display: flex;
    text-align: center;
    justify-content: center;
}
.middleware {
    display: none;
}
@media (max-width: 767px) {
    #example-res-images th {
        font-size: 12px;
        word-wrap: break-word;
        word-break: break-word;
        white-space: normal;
        overflow-wrap: break-word;
    }
}
#example-res-images .tr-head {
    display: grid !important;
    grid-template-columns: 1fr 1fr;
}
#example-res-images .tr-head th:nth-child(1),
#example-res-images .tr-head th:nth-child(2) {
    grid-column: 1;
}
#example-res-images .tr-head th:nth-child(2) {
    display: none !important;
}
#example-res-images .tr-head th:last-child {
    grid-column: 2;
    grid-row: 1;
}
#example-res-images .tr-body {
    display: grid !important;
    grid-template-columns: 1fr 1fr;
    position: relative;
}
#example-res-images .tr-body td:nth-child(1),
#example-res-images .tr-body td:nth-child(2) {
    grid-column: 1;
    grid-row: 1;
    position: relative;
}
#example-res-images .tr-body td:last-child {
    grid-column: 2;
    grid-row: 1;
}
"""
with gr.Blocks(css=css) as WeShop:
    # Per-session device id; generate_image fills it in on first use.
    current_did = gr.State(value='')
    # Resolve assets relative to this file (like example_path), not the working
    # directory the app happens to be launched from.
    gr.HTML(load_description(os.path.join(example_path, "title.html")))
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Step 1: Upload an image needing hand enhancement")
            main_image_input = gr.ImageEditor(
                height="500px",
                type="pil",
                label="Main Image",
                brush=gr.Brush(
                    default_size=30,
                    colors=["rgba(117, 48, 254, 0.5)"],
                    color_mode="fixed",
                    default_color="rgba(117, 48, 254, 0.5)",
                ),
                eraser=gr.Eraser(
                    default_size=30,
                ),
                layers=False,
                elem_classes=["image-container", "hide-buttons"]
            )
            main_example = gr.Examples(
                inputs=main_image_input,
                examples_per_page=12,
                examples=clothing_list_path,
                elem_id="example-images",
                outputs=main_image_input,
            )
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Step 2: Press 'Generate' to get the result")
                    output = gr.Image(
                        label="Result",
                        elem_classes=["image-container", "hide-buttons"],
                        interactive=False
                    )
                    with gr.Row():
                        submit_button = gr.Button("Generate")
                        submit_button.click(
                            fn=generate_image,
                            inputs=[main_image_input, current_did],
                            outputs=[output, current_did],
                            concurrency_limit=None
                        )
                    with gr.Row():
                        gr.LoginButton()
            with gr.Column():
                # Hidden helper inputs (CSS class "middleware") that let the
                # show-case examples push a source image and mask into the editor
                # via process_show_case_image.
                main_image_middleware = gr.Image(
                    image_mode='RGBA',
                    type="pil",
                    label="Edited Image",
                    elem_classes=["middleware"]
                )
                mask_image_middleware = gr.Image(
                    image_mode='RGBA',
                    type="pil",
                    elem_classes=["middleware"]
                )
                # Rows are assets/examples/result_NN_01..03.png: source image,
                # brush mask, and result, for example sets 01-03.
                show_case_examples = [
                    [os.path.join(example_path, "examples", f"result_{row:02d}_{col:02d}.png")
                     for col in (1, 2, 3)]
                    for row in (1, 2, 3)
                ]
                show_case = gr.Examples(
                    examples=show_case_examples,
                    inputs=[main_image_middleware, mask_image_middleware, output],
                    outputs=[main_image_input, main_image_input, output],
                    elem_id="example-res-images",
                    fn=process_show_case_image,
                    run_on_click=True,
                )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    WeShop.queue(api_open=False).launch(show_api=False)