File size: 6,857 Bytes
a829665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b799f2
 
 
 
 
 
 
a2cc10f
fd70c6f
74c4534
a829665
2b799f2
74c4534
2b799f2
a829665
fd70c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74c4534
fd70c6f
 
 
 
a829665
fd70c6f
a829665
fd70c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
bf83e6e
 
a829665
fd70c6f
 
 
 
 
 
 
a829665
fd70c6f
 
 
 
 
 
 
 
 
a829665
fd70c6f
 
 
 
 
 
 
 
 
 
 
 
bf83e6e
fd70c6f
 
 
 
 
 
 
 
74c4534
2b799f2
a829665
2b799f2
a829665
 
 
2b799f2
 
 
 
 
bf83e6e
2b799f2
a829665
 
 
 
 
2b799f2
bf83e6e
 
 
 
 
 
a829665
bf83e6e
 
2b799f2
74c4534
2b799f2
fd70c6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
{
    "inputs": {
        "prompt": "A text prompt to generate the image from.",
        "image1": "Filepath for the first image prompt.",
        "image2": "Filepath for the second image prompt.",
        "image3": "Filepath for the third image prompt.",
        "image4": "Filepath for the fourth image prompt."
    },
    "outputs": {
        "output_image": "The final generated image or step preview.",
        "status": "Current status of the image generation process (e.g., Running, Success, Failed)."
    },
    "description": "This Gradio app takes a text prompt and up to four image prompts to generate an AI-powered image. It provides real-time updates on the job's progress and displays the final image once complete."
}
"""

import base64
import json
import os
from io import BytesIO
from urllib.parse import urlsplit

import gradio as gr
import requests
from PIL import Image
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

# Host server configuration: base URL of the image-generation API server.
# Set the IMAGE_API_HOST environment variable to target another deployment.
# A trailing "/" is stripped because request paths are appended as f"{host}/...".
host = os.environ.get("IMAGE_API_HOST", "http://18.119.36.46:8888").rstrip("/")

def image_prompt(prompt, image1, image2, image3, image4):
    """Submit a text prompt and up to four image prompts as an asynchronous job.

    Args:
        prompt: Text prompt forwarded to the server unchanged.
        image1, image2, image3, image4: Filepaths of the image prompts (as produced
            by ``gr.Image(type="filepath")``). ``None``/empty values are skipped, so
            any subset of the four slots may be filled.

    Returns:
        The ``job_id`` reported by the server. Only this single value is returned,
        because the caller stores it directly in one ``gr.State`` output.

    Raises:
        gr.Error: If an image file cannot be read, the HTTP request fails, the reply
            is not valid JSON, or it carries no ``job_id``. Gradio shows the message
            to the user instead of storing a bogus value as the job id.
    """
    image_prompts = []
    for path in (image1, image2, image3, image4):
        if not path:
            continue  # slot left empty in the UI
        try:
            with open(path, "rb") as fh:
                encoded = base64.b64encode(fh.read()).decode("utf-8")
        except OSError as e:
            raise gr.Error(f"Could not read image prompt {path!r}: {e}") from e
        image_prompts.append(
            {
                "cn_img": encoded,
                "cn_stop": 1,
                "cn_weight": 1,
                "cn_type": "ImagePrompt",
            }
        )

    params = {
        "prompt": prompt,
        "image_prompts": image_prompts,
        "async_process": True,
    }

    # NOTE: urllib3's default Retry.allowed_methods excludes POST, so the
    # status_forcelist retries do not re-send this request; only connection
    # failures are retried. That avoids submitting the same job twice.
    retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)
    try:
        # The context manager closes the session's pooled connections afterwards.
        with requests.Session() as session:
            session.mount("http://", adapter)
            session.mount("https://", adapter)
            response = session.post(
                url=f"{host}/v2/generation/text-to-image-with-ip",
                data=json.dumps(params),
                headers={"Content-Type": "application/json"},
                timeout=15,
            )
            response.raise_for_status()  # surface bad HTTP status codes
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"API request failed: {e}") from e

    # Parsed separately: recent requests versions raise a JSON error that is also a
    # RequestException, which would otherwise be reported as a failed request.
    try:
        result = response.json()
    except ValueError as e:
        raise gr.Error("Failed to decode JSON response from the server.") from e

    job_id = result.get("job_id") if isinstance(result, dict) else None
    if not job_id:
        raise gr.Error("Job ID not found in response")
    return job_id

def query_status(job_id):
    """Poll the server once for the state of an image-generation job.

    Args:
        job_id: Identifier returned by the server when the job was submitted.

    Returns:
        A ``(image, status_text)`` pair. ``image`` is a ``PIL.Image.Image`` or
        ``None``: the final result on success, or the step preview while the job is
        running if the server sent one. ``status_text`` is a human-readable status
        for the UI. Failures are reported through ``status_text`` rather than raised.
    """
    # The server reports result URLs on 127.0.0.1. They are rewritten to the
    # hostname of ``host`` so this client can download them. If ``host`` has no
    # parsable hostname, the URLs are left unchanged.
    public_hostname = urlsplit(host).hostname or "127.0.0.1"

    retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)
    try:
        with requests.Session() as session:
            session.mount("http://", adapter)
            session.mount("https://", adapter)

            response = session.get(
                f"{host}/v1/generation/query-job",
                # requests URL-encodes these, so an unusual job id cannot corrupt the query.
                params={"job_id": job_id, "require_step_preview": "true"},
                timeout=15,
            )
            response.raise_for_status()
            try:
                job_data = response.json()
            except ValueError:
                return None, "Failed to decode JSON response from the server."

            job_stage = job_data.get("job_stage")

            if job_stage == "SUCCESS":
                job_result = job_data.get("job_result") or []
                # A missing/empty result list is reported, not turned into an IndexError.
                final_image_url = job_result[0].get("url") if job_result else None
                if not final_image_url:
                    return None, "Final image URL not found in the job data."
                final_image_url = final_image_url.replace("127.0.0.1", public_hostname)
                image_response = session.get(final_image_url, timeout=15)
                image_response.raise_for_status()
                image = Image.open(BytesIO(image_response.content))
                return image, "Job completed successfully."

            if job_stage == "RUNNING":
                step_preview_base64 = job_data.get("job_step_preview")
                if step_preview_base64:
                    image = Image.open(BytesIO(base64.b64decode(step_preview_base64)))
                    return image, "Job is running, step preview available."
                return None, "Job is running, no step preview available."

            # NOTE(review): the endpoint paths look like Fooocus-API, which appears to
            # report "ERROR" and "WAITING" stages; "FAILED" is kept from the original
            # check. Confirm the stage names against the server.
            if job_stage in ("FAILED", "ERROR"):
                return None, "Job failed."
            if job_stage == "WAITING":
                return None, "Job is waiting in the queue."
            return None, f"Unknown job stage: {job_stage}."

    except requests.exceptions.RequestException as e:
        return None, f"API request failed: {e}"
    except Exception as e:
        # UI boundary (e.g. unreadable image bytes): report instead of crashing the poller.
        return None, f"An unexpected error occurred: {e}"

def gradio_app():
    """Build and launch the Gradio UI for the image generator.

    The page has a text prompt, four optional image-prompt slots, a Generate button,
    an output image and a status box. Clicking Generate calls ``image_prompt`` and
    stores its result in per-session ``gr.State``. A load event re-runs every
    2 seconds, polling ``query_status`` for that job and refreshing the image and
    status.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Image Generator\nEnter a text prompt and upload up to 4 images to generate a unique AI-powered image.")

        prompt = gr.Textbox(label="Text Prompt", placeholder="Describe the image you want to generate")
        with gr.Row():
            image1 = gr.Image(label="Image Prompt 1", type="filepath")
            image2 = gr.Image(label="Image Prompt 2", type="filepath")
            image3 = gr.Image(label="Image Prompt 3", type="filepath")
            image4 = gr.Image(label="Image Prompt 4", type="filepath")

        generate_button = gr.Button("Generate Image")

        # Kept visible: nothing in this app ever toggles visibility, so a hidden
        # component would never display step previews or the final image.
        output_image = gr.Image(label="Generated Image")
        status = gr.Textbox(label="Status", value="Awaiting input...", interactive=False)

        job_id = gr.State()

        def update_image(current_job_id):
            """Poll the job held in session state; return ``(image, status_text)``."""
            if isinstance(current_job_id, tuple) and current_job_id:
                # Defensive: a failed submission may leave a ``(None, message)`` tuple
                # in state instead of a job id; show the message instead of polling it.
                return None, str(current_job_id[-1])
            if not current_job_id:
                # Nothing submitted yet: keep the initial status text.
                return None, "Awaiting input..."
            return query_status(current_job_id)

        generate_button.click(image_prompt, inputs=[prompt, image1, image2, image3, image4], outputs=[job_id])
        demo.load(update_image, inputs=job_id, outputs=[output_image, status], every=2)

    demo.launch()

if __name__ == "__main__":
    # Script entry point: build and serve the UI only when run directly, not on import.
    gradio_app()