File size: 4,017 Bytes
23ed1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder


@dataclass
class TryOnDiffusionAPIResponse:
    """Outcome of a ``TryOnDiffusionClient`` request.

    Only ``status_code`` is always set; the other fields stay ``None`` unless
    the response populated them.
    """

    # HTTP status code of the response; 0 when the request could not be sent.
    status_code: int
    # Response body decoded with OpenCV's ``imdecode`` (non-raw mode, 200 only);
    # None if the body was not a decodable image.
    image: Optional[np.ndarray] = None
    # Raw response body: set in raw mode and for every non-200 response.
    response_data: Optional[bytes] = None
    # Value of the ``detail`` key of a JSON error body, copied verbatim.
    # NOTE(review): typed as str, but a server could send a non-string detail.
    error_details: Optional[str] = None
    # Integer parsed from the ``X-Seed`` response header, when present.
    seed: Optional[int] = None


class TryOnDiffusionClient:
    """HTTP client for the try-on diffusion service's ``/try-on-file`` endpoint.

    Image arrays are uploaded as JPEG files in a multipart form; prompts and
    options are sent as plain form fields. Transport failures are logged and
    reported as a ``TryOnDiffusionAPIResponse`` with ``status_code == 0``
    instead of being raised.
    """

    def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = ""):
        """Create a client.

        Args:
            base_url: Root URL of the service. Trailing slashes are removed so
                endpoint paths (which start with ``/``) can be appended directly.
            api_key: Sent verbatim in the ``X-API-Key`` header of every request.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # rstrip() instead of inspecting base_url[-1]: an empty base_url no
        # longer raises IndexError, and repeated trailing slashes are removed.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """JPEG-encode *image* (quality 99) and wrap it as a multipart file tuple.

        Args:
            image: Array accepted by ``cv2.imencode`` (OpenCV treats 3-channel
                input as BGR).

        Returns:
            ``(filename, file_object, content_type)`` as used for
            ``MultipartEncoder`` file fields.

        Raises:
            ValueError: If ``cv2.imencode`` reports that encoding failed.
        """
        ok, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        if not ok:
            # Without this check a failed encode would upload an empty/garbage
            # file or fail later with an unrelated AttributeError.
            raise ValueError("cv2.imencode failed to encode the image as JPEG")
        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    def try_on_file(
        self,
        clothing_image: Optional[np.ndarray] = None,
        clothing_prompt: Optional[str] = None,
        avatar_image: Optional[np.ndarray] = None,
        avatar_prompt: Optional[str] = None,
        avatar_sex: Optional[str] = None,
        background_image: Optional[np.ndarray] = None,
        background_prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        num_images: int = 1,
        seed: int = -1,
        raw_response: bool = False,
        timeout: Optional[float] = None,
    ) -> TryOnDiffusionAPIResponse:
        """POST a multipart try-on request to ``/try-on-file`` and parse the reply.

        Arguments left as ``None`` are omitted from the form. Image arguments
        are JPEG-encoded and sent as file fields; every other argument is sent
        as a form field of the same name.

        Args:
            clothing_image: Sent as the ``clothing_image`` file field.
            clothing_prompt: Sent as the ``clothing_prompt`` form field.
            avatar_image: Sent as the ``avatar_image`` file field.
            avatar_prompt: Sent as the ``avatar_prompt`` form field.
            avatar_sex: Sent as the ``avatar_sex`` form field (accepted values
                are defined by the server).
            background_image: Sent as the ``background_image`` file field.
            background_prompt: Sent as the ``background_prompt`` form field.
            negative_prompt: Sent as the ``negative_prompt`` form field.
            num_images: Always sent, stringified, as ``num_images``.
            seed: Always sent, stringified, as ``seed``. The default ``-1``
                presumably lets the server pick a seed — confirm server-side.
            raw_response: If true, a 200 body is returned as bytes in
                ``response_data`` instead of being decoded into ``image``.
            timeout: Seconds forwarded to ``requests.post``. ``None`` (the
                default) means no timeout.

        Returns:
            A ``TryOnDiffusionAPIResponse``. ``status_code`` is 0 if the request
            could not be sent. On a 200 response it carries ``image`` (or
            ``response_data`` in raw mode) plus ``seed`` from the ``X-Seed``
            header when present. On any other status it carries the body in
            ``response_data`` and any JSON ``detail`` in ``error_details``.

        Raises:
            ValueError: If an input image cannot be JPEG-encoded (OpenCV may
                also raise ``cv2.error`` directly for invalid arrays).
        """
        url = self._base_url + "/try-on-file"

        request_data = {"num_images": str(num_images), "seed": str(seed)}

        # (form field name, value, whether the value is an image to JPEG-encode).
        # The tuple order fixes the order of parts in the multipart body.
        optional_fields = (
            ("clothing_image", clothing_image, True),
            ("clothing_prompt", clothing_prompt, False),
            ("avatar_image", avatar_image, True),
            ("avatar_prompt", avatar_prompt, False),
            ("avatar_sex", avatar_sex, False),
            ("background_image", background_image, True),
            ("background_prompt", background_prompt, False),
            ("negative_prompt", negative_prompt, False),
        )
        for name, value, is_image in optional_fields:
            if value is not None:
                request_data[name] = self._image_to_upload_file(value) if is_image else value

        multipart_data = MultipartEncoder(fields=request_data)

        try:
            response = requests.post(
                url,
                data=multipart_data,
                headers={"Content-Type": multipart_data.content_type, "X-API-Key": self._api_key},
                timeout=timeout,
            )
        except Exception as e:
            # Deliberate best-effort boundary: any transport failure is logged
            # (with traceback) and reported to the caller as status_code 0.
            self._logger.exception(e)
            return TryOnDiffusionAPIResponse(status_code=0)

        if response.status_code != 200:
            self._logger.warning(
                "Request failed, status code: %s, response: %s", response.status_code, response.content
            )

        return self._parse_response(response, raw_response)

    def _parse_response(self, response: requests.Response, raw_response: bool) -> TryOnDiffusionAPIResponse:
        """Convert an HTTP response into a ``TryOnDiffusionAPIResponse``.

        Args:
            response: The response returned by ``requests.post``.
            raw_response: Keep a 200 body as bytes instead of decoding it.
        """
        result = TryOnDiffusionAPIResponse(status_code=response.status_code)

        if response.status_code == 200:
            if raw_response:
                result.response_data = response.content
            else:
                result.image = self._decode_image(response.content)
            result.seed = self._parse_seed(response.headers.get("X-Seed"))
        else:
            result.response_data = response.content
            result.error_details = self._parse_error_details(response.content)

        return result

    def _decode_image(self, content: bytes) -> Optional[np.ndarray]:
        """Decode *content* with ``cv2.imdecode``; return ``None`` if it is not a decodable image."""
        try:
            image = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
        except (cv2.error, TypeError, ValueError):
            # e.g. an empty body, which OpenCV rejects with cv2.error rather
            # than returning None. Narrower than a bare except, so Ctrl-C and
            # SystemExit still propagate.
            image = None
        if image is None:
            self._logger.warning("Could not decode an image from the response body")
        return image

    def _parse_seed(self, header_value: Optional[str]) -> Optional[int]:
        """Parse an ``X-Seed`` header value; ``None`` if absent or not an integer."""
        if header_value is None:
            return None
        try:
            return int(header_value)
        except ValueError:
            self._logger.warning("Ignoring non-integer X-Seed header: %r", header_value)
            return None

    @staticmethod
    def _parse_error_details(content: Optional[bytes]):
        """Return the ``detail`` value of a JSON-object body, else ``None``.

        The value is passed through unchanged, so it has whatever JSON type
        the server sent under ``detail``.
        """
        if content is None:
            return None
        try:
            response_json = json.loads(content.decode("utf-8"))
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both subclass
            # ValueError: the body simply is not a JSON error payload.
            return None
        if isinstance(response_json, dict):
            return response_json.get("detail")
        return None