File size: 4,208 Bytes
2b799f2
 
 
 
 
1479d8c
2b799f2
 
a2cc10f
e6f2d3f
74c4534
2b799f2
74c4534
2b799f2
79b7222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd70c6f
79b7222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d9ade7
79b7222
 
 
 
3d9ade7
79b7222
 
 
 
 
 
 
3d9ade7
79b7222
 
 
 
 
 
3d9ade7
79b7222
 
 
74c4534
2b799f2
 
1479d8c
2b799f2
 
 
 
 
1479d8c
 
bf83e6e
2b799f2
3d9ade7
a829665
2b799f2
74c4534
2b799f2
1479d8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import json
import base64
import time
import gradio as gr
from PIL import Image
from io import BytesIO
import os

# Base URL of the generation API server. It can be overridden with the
# IMAGE_API_HOST environment variable instead of editing the source. A trailing
# "/" is stripped because endpoint URLs are built as f"{host}/v1/...".
host = os.environ.get("IMAGE_API_HOST", "http://18.119.36.46:8888").rstrip("/")

def image_prompt(prompt, image1, image2, image3, image4, *, poll_interval=2, max_wait=600):
    """Submit an image-prompt generation job to ``host`` and stream its progress.

    Up to four reference images are base64-encoded and posted, together with
    ``prompt``, to ``/v2/generation/text-to-image-with-ip`` as an asynchronous
    job. The job is then polled via ``/v1/generation/query-job``. Polling stops
    when the job reports ``SUCCESS`` or ``FAILED``, when a request fails, or
    when ``max_wait`` seconds have passed.

    This is a generator: Gradio pushes every yielded value to the UI.

    Args:
        prompt: Text prompt sent as the ``"prompt"`` field.
        image1, image2, image3, image4: File paths of the reference images
            (Gradio ``type="filepath"``). Empty slots (``None`` / ``""``)
            are skipped rather than passed to ``open``.
        poll_interval: Seconds to sleep between status queries.
        max_wait: Upper bound, in seconds, on how long to keep polling.

    Yields:
        ``(image, status)`` pairs. ``image`` is a ``PIL.Image`` (a step
        preview or the final result) or ``None``; ``status`` is a
        human-readable message.
    """
    from urllib.parse import urlsplit  # only this function needs it

    image_prompts = []
    for path in (image1, image2, image3, image4):
        if not path:
            continue  # Gradio passes None for an image slot left empty
        with open(path, "rb") as fh:  # close the handle promptly
            encoded = base64.b64encode(fh.read()).decode("utf-8")
        image_prompts.append(
            {
                "cn_img": encoded,
                "cn_stop": 1,
                "cn_weight": 1,
                "cn_type": "ImagePrompt",
            }
        )

    params = {
        "prompt": prompt,
        "image_prompts": image_prompts,
        "async_process": True,
    }

    # Result URLs may point at 127.0.0.1 (the server's loopback address).
    # Swap in the hostname taken from `host` so the URL is reachable from
    # this client, instead of hard-coding the server IP a second time.
    public_hostname = urlsplit(host).hostname or "127.0.0.1"

    # Retry transient 502/503/504 responses with exponential backoff.
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
    with requests.Session() as session:  # closed once the generator finishes
        adapter = HTTPAdapter(max_retries=retries)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        try:
            response = session.post(
                url=f"{host}/v2/generation/text-to-image-with-ip",
                data=json.dumps(params),
                headers={"Content-Type": "application/json"},
                timeout=10,
            )
            response.raise_for_status()
            result = response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError also covers a non-JSON response body.
            yield None, f"Failed to submit job: {exc}"
            return

        job_id = result.get("job_id") if isinstance(result, dict) else None
        if not job_id:
            yield None, "Job ID not found."
            return

        query_url = f"{host}/v1/generation/query-job"
        query_params = {"job_id": job_id, "require_step_preview": "true"}
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            try:
                response = session.get(query_url, params=query_params, timeout=10)
                response.raise_for_status()
                job_data = response.json()
            except (requests.RequestException, ValueError) as exc:
                yield None, f"Failed to query job status: {exc}"
                return
            if not isinstance(job_data, dict):
                job_data = {}

            job_stage = job_data.get("job_stage")
            job_step_preview = job_data.get("job_step_preview")
            job_result = job_data.get("job_result")

            # Push a live preview to the UI while the job is running.
            if job_stage == "RUNNING" and job_step_preview:
                image = Image.open(BytesIO(base64.b64decode(job_step_preview)))
                yield image, f"Job is running. Stage: {job_stage}"

            elif job_stage == "SUCCESS":
                # job_result may be missing or empty; guard before indexing.
                first = job_result[0] if job_result else None
                final_image_url = first.get("url") if isinstance(first, dict) else None
                if not final_image_url:
                    yield None, "Final image URL not found in the job data."
                    return
                final_image_url = final_image_url.replace("127.0.0.1", public_hostname)
                try:
                    image_response = session.get(final_image_url, timeout=10)
                    image_response.raise_for_status()
                except requests.RequestException as exc:
                    yield None, f"Failed to download final image: {exc}"
                    return
                image = Image.open(BytesIO(image_response.content))
                yield image, "Job completed successfully."
                return

            elif job_stage == "FAILED":
                yield None, "Job failed."
                return

            time.sleep(poll_interval)

        # Bounded polling: a job stuck in any other stage must not hang the UI forever.
        yield None, f"Timed out after {max_wait} seconds waiting for the job to finish."

def gradio_app():
    """Build the Gradio UI for image-prompt generation and launch it.

    The UI has:
    - a text prompt and four image inputs, which are passed to
      ``image_prompt`` as file paths;
    - an output image and a status textbox, which receive the
      ``(image, status)`` pairs that ``image_prompt`` yields.
    """
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
        with gr.Row():
            image_inputs = [
                gr.Image(label=f"Image Prompt {i}", type="filepath")
                for i in range(1, 5)
            ]
        output_image = gr.Image(label="Generated Image")
        status = gr.Textbox(label="Status")

        generate_button = gr.Button("Generate Image")
        # image_prompt is a generator, so Gradio streams each yielded value on
        # its own. Event listeners take no `stream` keyword (passing one raises
        # TypeError).
        generate_button.click(
            fn=image_prompt,
            inputs=[prompt, *image_inputs],
            outputs=[output_image, status],
        )

    # Generator callbacks need the request queue: Gradio 3.x requires an explicit
    # queue() call, and newer versions already enable it by default.
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    gradio_app()