File size: 9,193 Bytes
884e760
6c74fa1
2d8c11a
6c74fa1
6a1229b
 
 
2d8c11a
 
 
ca9ec31
2d8c11a
 
d4d1341
2d8c11a
 
 
 
 
 
 
c060c17
ff363f9
 
6a1229b
ebb5e31
2d8c11a
a08eb0d
2d8c11a
 
 
ebb5e31
 
 
 
 
2d8c11a
 
 
ebb5e31
2d8c11a
 
6c74fa1
 
ebb5e31
6c74fa1
ebb5e31
 
 
d4d1341
ebb5e31
 
 
 
 
 
 
 
a08eb0d
ebb5e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c74fa1
d4d1341
 
6a1229b
ebb5e31
6a1229b
2d8c11a
6c74fa1
d4d1341
 
6c74fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbf369
47009ed
 
c0ddebd
 
 
 
47009ed
1318a4b
eaaba91
1318a4b
eaaba91
1318a4b
dcbf369
 
6c74fa1
 
 
 
 
ebb5e31
a08eb0d
ebb5e31
6c74fa1
c2c9be0
02c3073
 
fd27637
02c3073
 
0e5b8cf
c3bac7e
 
ebb5e31
6c74fa1
 
4951f1d
ebb5e31
 
 
d4d1341
 
 
 
4951f1d
ebb5e31
 
24c81de
ebb5e31
 
 
 
 
d4d1341
 
 
 
 
 
ebb5e31
 
62aa379
ebb5e31
 
 
d4d1341
7678535
ebb5e31
 
24c81de
ebb5e31
 
 
 
 
 
6c74fa1
 
 
 
 
 
 
 
 
 
45299c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import gradio as gr
import numpy as np

import os
from PIL import Image
import requests
from io import BytesIO
import io
import base64

# API token is read from a secret environment variable so it never appears in
# the source; it is None when the variable is unset (requests then go out
# unauthenticated and the API is expected to reject them).
hf_token = os.environ.get("HF_TOKEN_API_DEMO") # we get it from a secret env variable, such that it's private
# Header dict attached to every Bria engine API call below.
auth_headers = {"api_token": hf_token}

def convert_image_to_base64_string(mask_image):
    """Serialize *mask_image* to PNG and return it base64-encoded.

    The returned string carries a leading "," because the downstream
    base64-to-image helper expects that prefix (it is redundant in a URL,
    but required by the decoder).
    """
    png_buffer = io.BytesIO()
    mask_image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "," + encoded

def download_image(url):
    """Fetch *url* over HTTP and return the payload decoded as an RGB PIL image."""
    resp = requests.get(url)
    return Image.open(BytesIO(resp.content)).convert("RGB")

def lifestyle_shot_by_text_api_call(image_base64_file, prompt):
    """Call the Bria "lifestyle shot by text" endpoint and return the result.

    Parameters
    ----------
    image_base64_file : str
        Base64-encoded product image (with the leading "," prefix produced
        by ``convert_image_to_base64_string``).
    prompt : str
        Scene description used to generate the background.

    Returns
    -------
    PIL.Image.Image
        The first generated result image, downloaded and converted to RGB.

    Raises
    ------
    requests.HTTPError
        If the API responds with a non-2xx status code.
    """
    url = "http://engine.prod.bria-api.com/v1/product/lifestyle_shot_by_text"

    payload = {
        "file": image_base64_file,
        "scene_description": prompt,
        "num_results": 1,
        "sync": True,
        "original_quality": True,
        "optimize_description": True,
    }
    # Timeout prevents the demo from hanging forever on a stuck request;
    # raise_for_status surfaces API errors directly instead of letting them
    # appear later as an opaque KeyError on 'result'.
    response = requests.post(url, json=payload, headers=auth_headers, timeout=300)
    response.raise_for_status()
    result = response.json()
    return download_image(result['result'][0][0])


def predict_ref_by_text(input_image, prompt):
    """Gradio handler: generate a product shot from an image and a text prompt.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Product image uploaded by the user.
    prompt : str
        Scene description typed by the user.

    Returns
    -------
    PIL.Image.Image
        The generated lifestyle shot.
    """
    image_base64_file = convert_image_to_base64_string(input_image)
    return lifestyle_shot_by_text_api_call(image_base64_file, prompt)


def lifestyle_shot_by_image_api_call(image_base64_file, ref_image_base64_file):
    """Call the Bria "lifestyle shot by image" endpoint and return the result.

    Parameters
    ----------
    image_base64_file : str
        Base64-encoded product image (with the leading "," prefix).
    ref_image_base64_file : str
        Base64-encoded reference image used as background inspiration.

    Returns
    -------
    PIL.Image.Image
        The first generated result image, downloaded and converted to RGB.

    Raises
    ------
    requests.HTTPError
        If the API responds with a non-2xx status code.
    """
    url = "http://engine.prod.bria-api.com/v1/product/lifestyle_shot_by_image"

    payload = {
        "file": image_base64_file,
        "ref_image_file": ref_image_base64_file,
        "num_results": 1,
        "sync": True,
        "original_quality": True,
        "optimize_description": True,
    }
    # Timeout prevents the demo from hanging forever on a stuck request;
    # raise_for_status surfaces API errors directly instead of letting them
    # appear later as an opaque KeyError on 'result'.
    response = requests.post(url, json=payload, headers=auth_headers, timeout=300)
    response.raise_for_status()
    result = response.json()
    return download_image(result['result'][0][0])


def predict_ref_by_image(init_image, ref_image):
    """Gradio handler: generate a product shot guided by a reference image.

    Both images are encoded to base64 strings and forwarded to the Bria
    "lifestyle shot by image" API; the generated PIL image is returned.
    """
    encoded_product = convert_image_to_base64_string(init_image)
    encoded_reference = convert_image_to_base64_string(ref_image)
    return lifestyle_shot_by_image_api_call(encoded_product, encoded_reference)

def on_change_prompt(img: Image.Image | None, prompt: str | None):
    """Enable the run button only once both an image and a prompt are present."""
    ready = bool(img) and bool(prompt)
    return gr.update(interactive=ready)

# Custom CSS for the demo: container width, fixed-height image widgets,
# share-button styling, spinner keyframes, footer appearance, and a
# full-width centered run button.
css = '''
.gradio-container{max-width: 1100px !important}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button {
    width: 100%;
    height: 50px;  /* Set a fixed height for the button */
    display: flex;
    align-items: center;
    justify-content: center;
}
#output-img img, #image_upload img {
    object-fit: contain; /* Ensure aspect ratio is preserved */
    width: 100%;
    height: auto; /* Let height adjust automatically */
}
#prompt-container{margin-top:-18px;}
#prompt-container .form{border-top-left-radius: 0;border-top-right-radius: 0}
#image_upload{border-bottom-left-radius: 0px;border-bottom-right-radius: 0px}
'''

# ---------------------------------------------------------------------------
# UI definition: two tabs sharing the same layout (inputs on the left,
# generated shot on the right), then the app is queued and launched.
#
# NOTE(review): `image`, `btn`, and `image_out` are re-bound in the second
# tab. This works only because the first tab's event listeners are registered
# before the names are reused — keep the statement order as-is.
# NOTE(review): both `btn.click` calls pass api_name='run'; recent Gradio
# versions warn about / deduplicate duplicate endpoint names — confirm the
# intended API surface.
# ---------------------------------------------------------------------------
image_blocks = gr.Blocks(css=css, elem_id="total-container")
with image_blocks as demo:
    gr.Markdown("## Product Shot Generation")
    gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This demo showcases the <strong>Lifestyle Product Shot by Text</strong> and <strong>Lifestyle Product Shot by Image</strong> feature, enabling users to generate product backgrounds effortlessly.<br>
            With <strong>Lifestyle Product Shot by Text</strong>, users can create backgrounds using descriptive textual prompts, 
            while <strong>Lifestyle Product Shot by Image</strong> allows backgrounds to be generated based on a reference image for inspiration.<br>
            The pipeline comprises multiple components, including <a href="https://huggingface.co/briaai/BRIA-2.3" target="_blank">briaai/BRIA-2.3</a>, 
            <a href="https://huggingface.co/briaai/RMBG-2.0" target="_blank">briaai/RMBG-2.0</a>, <a href="https://huggingface.co/briaai/BRIA-2.3-ControlNet-BG-Gen" target="_blank">briaai/BRIA-2.3-ControlNet-BG-Gen</a> and 
            <a href="https://huggingface.co/briaai/Image-Prompt" target="_blank">briaai/Image-Prompt</a>, all trained on licensed data.<br>
            This ensures full legal liability coverage for copyright and privacy infringement.<br>
            Notes:<br>
            - High-resolution images may take longer to process.<br>
            - For best results in reference by image: make sure the foreground in the image is already located in the wanted position and scale, relative to the elements in the reference image.<br>
          </p>
        ''')
    # --- Tab 1: background generated from a text prompt ---
    with gr.Tab(label="By scene description", id="tab_prompt"):

        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Input")
                prompt = gr.Textbox(label="scene description", placeholder="Enter your scene description here...")
                with gr.Row(elem_id="prompt-container", equal_height=True):
                    with gr.Column():
                        btn = gr.Button("Generate Product Shot!", elem_id="run_button")
            
            with gr.Column():
                image_out = gr.Image(label="Output", elem_id="output-img")

        # Keep the run button disabled until both an image and a prompt exist.
        for inp in [image, prompt]:
            inp.change(
                fn=on_change_prompt,
                inputs=[image, prompt],
                outputs=[btn],
            )
        btn.click(fn=predict_ref_by_text, inputs=[image, prompt], outputs=[image_out], api_name='run')
    
    # --- Tab 2: background generated from a reference image ---
    with gr.Tab(label="By reference image", id="tab_ref_image"):

        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Input")
                ref_image = gr.Image(type="pil", label="Reference Image")
                with gr.Row(elem_id="prompt-container", equal_height=True):
                    with gr.Column():
                        btn = gr.Button("Generate Product Shot!", elem_id="run_button")
            
            with gr.Column():
                image_out = gr.Image(label="Output", elem_id="output-img")

        # No enable/disable gating here: the button is always clickable.
        btn.click(fn=predict_ref_by_image, inputs=[image, ref_image], outputs=[image_out], api_name='run')

    gr.HTML(
        """
            <div class="footer">
                <p>Model by <a href="https://huggingface.co/diffusers" style="text-decoration: underline;" target="_blank">Diffusers</a> - Gradio Demo by 🤗 Hugging Face
                </p>
            </div>
        """
    )

# Queue limits concurrent jobs; the public API surface is disabled.
image_blocks.queue(max_size=25, api_open=False).launch(show_api=False)