File size: 3,630 Bytes
2b799f2
 
 
 
 
1479d8c
2b799f2
 
a2cc10f
e6f2d3f
74c4534
2b799f2
74c4534
2b799f2
86bb1a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b7222
 
86bb1a1
79b7222
 
fd70c6f
79b7222
 
 
 
86bb1a1
79b7222
86bb1a1
79b7222
 
86bb1a1
 
797572d
74c4534
86bb1a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b799f2
 
1479d8c
2b799f2
 
 
 
 
1479d8c
 
bf83e6e
2b799f2
797572d
a829665
2b799f2
74c4534
2b799f2
1479d8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import json
import base64
import time
import gradio as gr
from PIL import Image
from io import BytesIO
import os

# Base URL of the image-generation API server. Defaults to the original
# hard-coded address; set IMAGE_PROMPT_API_HOST to point at another server.
# A trailing slash is stripped because endpoint paths are appended as
# f"{host}/...".
host = (
    os.environ.get("IMAGE_PROMPT_API_HOST") or "http://18.119.36.46:8888"
).rstrip("/")

def image_prompt(prompt, image1, image2, image3, image4):
    """Generate an image from a text prompt plus up to four image prompts.

    Submits an asynchronous job to ``{host}/v2/generation/text-to-image-with-ip``
    and polls ``{host}/v1/generation/query-job`` every 2 seconds until the job
    succeeds, fails, or 5 minutes elapse. This is the Gradio click handler, so
    every outcome -- including network and decoding errors -- is reported via
    the returned status string instead of an exception.

    Args:
        prompt: Text prompt, sent as the ``"prompt"`` field.
        image1, image2, image3, image4: Paths to image-prompt files. Falsy
            values (e.g. ``None`` for an empty Gradio slot) are skipped.

    Returns:
        ``(image, status)``: a ``PIL.Image.Image`` or ``None``, and a
        human-readable status message. On timeout, ``image`` is the most
        recent step preview the server sent, if any.
    """
    with requests.Session() as session:
        # Retry transient gateway errors with exponential backoff.
        retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
        adapter = HTTPAdapter(max_retries=retries)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        try:
            image_data = [
                {
                    "cn_img": _encode_image_file(image_path),
                    "cn_stop": 1,
                    "cn_weight": 1,
                    "cn_type": "ImagePrompt",
                }
                for image_path in (image1, image2, image3, image4)
                if image_path
            ]
            response = session.post(
                url=f"{host}/v2/generation/text-to-image-with-ip",
                json={
                    "prompt": prompt,
                    "image_prompts": image_data,
                    "async_process": True,
                },
                timeout=10,
            )
            response.raise_for_status()
            result = response.json()
            job_id = result.get("job_id") if isinstance(result, dict) else None
            if not job_id:
                return None, "Job ID not found."
            return _poll_job(session, job_id)
        except requests.RequestException as err:
            # Connection failures, exhausted retries, HTTP error statuses and
            # (in newer requests) non-JSON bodies. Must precede the OSError
            # clause because RequestException subclasses IOError.
            return None, f"Request failed: {err}"
        except OSError as err:
            # Unreadable input file, or image bytes PIL cannot identify
            # (PIL's UnidentifiedImageError subclasses OSError).
            return None, f"Could not read image: {err}"
        except ValueError as err:
            # Malformed JSON or base64 payload from the server.
            return None, f"Invalid response from server: {err}"


def _encode_image_file(path):
    """Return the contents of the file at ``path`` as a base64 text string."""
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def _poll_job(session, job_id, max_wait_time=300, poll_interval=2):
    """Poll query-job for ``job_id`` until it finishes; return ``(image, status)``.

    Step previews are remembered rather than returned, so polling keeps going
    until the final image is available (returning on the first preview would
    leave the user with a partial image). Streaming previews live to the UI
    would require turning the handler into a generator.
    """
    deadline = time.monotonic() + max_wait_time
    last_preview = None
    while time.monotonic() < deadline:
        response = session.get(
            f"{host}/v1/generation/query-job",
            params={"job_id": job_id, "require_step_preview": "true"},
            timeout=10,
        )
        response.raise_for_status()
        job_data = response.json()
        job_stage = job_data.get("job_stage")

        if job_stage == "SUCCESS":
            return _fetch_result_image(session, job_data.get("job_result"))
        # NOTE(review): "FAILED" is kept from the original check; "ERROR" is
        # added because the Fooocus-API job-stage values appear to use it --
        # confirm against the server version in use.
        if job_stage in ("FAILED", "ERROR"):
            return None, "Job failed."

        # Still running: keep the newest preview and poll again.
        last_preview = job_data.get("job_step_preview") or last_preview
        time.sleep(poll_interval)

    if last_preview:
        preview_image = Image.open(BytesIO(base64.b64decode(last_preview)))
        return preview_image, "Job timed out."
    return None, "Job timed out."


def _fetch_result_image(session, job_result):
    """Download the first image listed in a finished job's ``job_result``."""
    from urllib.parse import urlsplit

    first = job_result[0] if job_result else None
    final_image_url = first.get("url") if isinstance(first, dict) else None
    if not final_image_url:
        return None, "Final image URL not found in the job data."

    # Result URLs may name the server's loopback address; rewrite them to the
    # API host's hostname so they are reachable from this client.
    api_hostname = urlsplit(host).hostname
    if api_hostname:
        final_image_url = final_image_url.replace("127.0.0.1", api_hostname)

    image_response = session.get(final_image_url, timeout=10)
    image_response.raise_for_status()
    return Image.open(BytesIO(image_response.content)), "Job completed successfully."

def gradio_app():
    """Build the image-prompt Gradio UI and launch it.

    Layout, top to bottom: a text prompt, a row of four file-path image
    inputs, the generated-image output, a status box, and a button that calls
    ``image_prompt`` with the prompt and the four images, writing its
    ``(image, status)`` result to the output and status widgets.
    """
    with gr.Blocks() as demo:
        prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
        with gr.Row():
            # Created in slot order 1..4 so they render left to right.
            image_inputs = [
                gr.Image(label=f"Image Prompt {slot}", type="filepath")
                for slot in range(1, 5)
            ]
        result_image = gr.Image(label="Generated Image")
        status_box = gr.Textbox(label="Status")

        run_button = gr.Button("Generate Image")
        run_button.click(
            image_prompt,
            inputs=[prompt_box, *image_inputs],
            outputs=[result_image, status_box],
        )

    demo.launch()

# Build and launch the UI only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    gradio_app()