File size: 4,006 Bytes
2b799f2
 
 
 
 
1479d8c
2b799f2
 
a2cc10f
74c4534
2b799f2
74c4534
2b799f2
3d9ade7
 
 
fd70c6f
3d9ade7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74c4534
2b799f2
 
1479d8c
2b799f2
 
 
 
 
1479d8c
 
bf83e6e
2b799f2
3d9ade7
a829665
2b799f2
74c4534
2b799f2
1479d8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import base64
import json
import os
import time
from io import BytesIO
from urllib.parse import urlsplit

import gradio as gr
import requests
from PIL import Image
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

# Base URL of the image-generation API server. It can be overridden with the
# IMAGE_API_HOST environment variable. Any trailing slash is stripped because
# request paths are appended as f"{host}/...".
host = os.environ.get("IMAGE_API_HOST", "http://18.119.36.46:8888").rstrip("/")

def image_prompt(prompt, image1, image2, image3, image4):
    try:
        image_sources = [open(image, "rb").read() for image in [image1, image2, image3, image4]]
        encoded_images = [base64.b64encode(img).decode('utf-8') for img in image_sources]

        params = {
            "prompt": prompt,
            "image_prompts": [
                {
                    "cn_img": encoded_img,
                    "cn_stop": 1,
                    "cn_weight": 1,
                    "cn_type": "ImagePrompt"
                } for encoded_img in encoded_images
            ],
            "async_process": True
        }
        
        session = requests.Session()
        retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
        session.mount('http://', HTTPAdapter(max_retries=retries))

        response = session.post(
            url=f"{host}/v2/generation/text-to-image-with-ip",
            data=json.dumps(params),
            headers={"Content-Type": "application/json"},
            timeout=10
        )
        response.raise_for_status()
        result = response.json()
        job_id = result.get('job_id')

        if job_id:
            while True:
                query_url = f"{host}/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
                response = session.get(query_url, timeout=10)
                response.raise_for_status()
                job_data = response.json()

                job_stage = job_data.get("job_stage")
                job_step_preview = job_data.get("job_step_preview")
                job_result = job_data.get("job_result")

                if job_stage == "RUNNING" and job_step_preview:
                    image = Image.open(BytesIO(base64.b64decode(job_step_preview)))
                    yield image, f"Job is running smoothly. Current stage: {job_stage}. Hang tight!"

                elif job_stage == "SUCCESS":
                    final_image_url = job_result[0].get("url")
                    if final_image_url:
                        final_image_url = final_image_url.replace("127.0.0.1", "18.119.36.46")
                        image_response = session.get(final_image_url, timeout=10)
                        image = Image.open(BytesIO(image_response.content))
                        yield image, "Job completed successfully. Enjoy your masterpiece!"
                        break
                    else:
                        yield None, "Final image URL not found. Something went amiss."
                        break

                elif job_stage == "FAILED":
                    yield None, "Job failed. Let's check the parameters and try again."
                    break

                time.sleep(2)

        else:
            yield None, "Job ID not found. Did we miss something in the setup?"

    except Exception as e:
        yield None, f"An error occurred: {str(e)}. We'll need to debug this."

def gradio_app():
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
        with gr.Row():
            image1 = gr.Image(label="Image Prompt 1", type="filepath")
            image2 = gr.Image(label="Image Prompt 2", type="filepath")
            image3 = gr.Image(label="Image Prompt 3", type="filepath")
            image4 = gr.Image(label="Image Prompt 4", type="filepath")
        output_image = gr.Image(label="Generated Image")
        status = gr.Textbox(label="Status")
        
        generate_button = gr.Button("Generate Image")
        generate_button.click(fn=image_prompt, inputs=[prompt, image1, image2, image3, image4], outputs=[output_image, status], stream=True)
        
    demo.launch()

if __name__ == "__main__":
    gradio_app()