rayochoajr's picture
Update app.py
79b7222 verified
raw
history blame
4.21 kB
import base64
import json
import os
import time
from io import BytesIO
from urllib.parse import urlsplit

import gradio as gr
import requests
from PIL import Image
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
# Base URL of the image-generation API server. Defaults to the original
# server; set the IMAGE_API_HOST environment variable to point the app
# elsewhere. A trailing slash is stripped because endpoint paths are appended
# as f"{host}/v1/..." and f"{host}/v2/...".
host = os.environ.get("IMAGE_API_HOST", "http://18.119.36.46:8888").rstrip("/")
def _encode_image_prompt(path):
    """Return one ``image_prompts`` request entry for the image file at *path*.

    The file is read in binary mode, closed afterwards, and base64-encoded.
    Stop, weight and type are the fixed values this app always sends.
    """
    with open(path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("utf-8")
    return {
        "cn_img": encoded,
        "cn_stop": 1,
        "cn_weight": 1,
        "cn_type": "ImagePrompt",
    }


def _poll_job(session, job_id, poll_interval=2, max_wait=600):
    """Poll the query-job endpoint for *job_id*, yielding ``(image, status)``.

    Yields a step preview while the job is RUNNING. Stops after a SUCCESS or
    FAILED stage, or after *max_wait* seconds. The time limit means a job that
    never reaches a terminal stage cannot keep the handler looping forever.
    HTTP/transport errors are raised as ``requests.RequestException``.
    """
    query_url = f"{host}/v1/generation/query-job"
    query = {"job_id": job_id, "require_step_preview": "true"}
    # Result URLs come back pointing at 127.0.0.1; rewrite them to the host
    # this client is configured to talk to.
    public_hostname = urlsplit(host).hostname or "127.0.0.1"
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        response = session.get(query_url, params=query, timeout=10)
        response.raise_for_status()
        job_data = response.json()
        job_stage = job_data.get("job_stage")
        if job_stage == "RUNNING":
            preview = job_data.get("job_step_preview")
            if preview:
                image = Image.open(BytesIO(base64.b64decode(preview)))
                yield image, f"Job is running. Stage: {job_stage}"
        elif job_stage == "SUCCESS":
            job_result = job_data.get("job_result") or []
            final_image_url = job_result[0].get("url") if job_result else None
            if not final_image_url:
                yield None, "Final image URL not found in the job data."
                return
            final_image_url = final_image_url.replace("127.0.0.1", public_hostname)
            image_response = session.get(final_image_url, timeout=10)
            image_response.raise_for_status()
            image = Image.open(BytesIO(image_response.content))
            yield image, "Job completed successfully."
            return
        elif job_stage == "FAILED":
            yield None, "Job failed."
            return
        time.sleep(poll_interval)
    yield None, f"Job timed out after {max_wait} seconds."


def image_prompt(prompt, image1, image2, image3, image4):
    """Generate an image from a text prompt plus up to four image prompts.

    This is a generator meant to be used as a streaming Gradio handler. It
    submits an asynchronous job to ``{host}/v2/generation/text-to-image-with-ip``.
    It then polls ``{host}/v1/generation/query-job`` every two seconds.

    Args:
        prompt: Text prompt, sent as the request's ``"prompt"`` field.
        image1, image2, image3, image4: File paths of the image prompts, or
            ``None`` for an empty upload slot. Empty slots are skipped.

    Yields:
        ``(image, status)`` tuples. ``image`` is a step preview while the job
        runs and the final image on success. It is ``None`` together with a
        message when the job fails, times out, or a request error occurs.
    """
    image_paths = [path for path in (image1, image2, image3, image4) if path]
    params = {
        "prompt": prompt,
        "image_prompts": [_encode_image_prompt(path) for path in image_paths],
        "async_process": True,
    }

    retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
    # The context manager closes the session's pooled connections when the
    # generator finishes or is closed by Gradio.
    with requests.Session() as session:
        adapter = HTTPAdapter(max_retries=retries)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        try:
            response = session.post(
                url=f"{host}/v2/generation/text-to-image-with-ip",
                data=json.dumps(params),
                headers={"Content-Type": "application/json"},
                timeout=10,
            )
            response.raise_for_status()
            job_id = response.json().get("job_id")
            if not job_id:
                yield None, "Job ID not found."
                return
            yield from _poll_job(session, job_id)
        except (requests.RequestException, ValueError) as err:
            # Surface transport/HTTP/JSON-decoding errors in the status box
            # instead of aborting the event with a traceback.
            yield None, f"Request error: {err}"
def gradio_app():
    """Build the Gradio UI around ``image_prompt`` and launch it.

    The page has:
    - a prompt textbox;
    - a row of four image-prompt uploads, passed to the handler as file paths;
    - an output image and a status textbox;
    - a "Generate Image" button wired to ``image_prompt``.
    """
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
        with gr.Row():
            image_inputs = [
                gr.Image(label=f"Image Prompt {index}", type="filepath")
                for index in range(1, 5)
            ]
        output_image = gr.Image(label="Generated Image")
        status = gr.Textbox(label="Status")
        generate_button = gr.Button("Generate Image")
        # No `stream=` argument: click() does not accept one. Generator
        # handlers stream their yields automatically through the queue.
        generate_button.click(
            fn=image_prompt,
            inputs=[prompt, *image_inputs],
            outputs=[output_image, status],
        )
    # image_prompt is a generator, so the queue must be enabled for its
    # partial results to stream. Gradio 3.x needs this explicitly; on 4.x the
    # queue is already on and this call is harmless.
    demo.queue()
    demo.launch()
if __name__ == "__main__":
    # Build and serve the UI only when this file is executed directly
    # (e.g. `python app.py`), not when it is imported as a module.
    gradio_app()