# app.py by rayochoajr, commit a829665 ("Update app.py", verified), 6.86 kB.
"""
{
"inputs": {
"prompt": "A text prompt to generate the image from.",
"image1": "Filepath for the first image prompt.",
"image2": "Filepath for the second image prompt.",
"image3": "Filepath for the third image prompt.",
"image4": "Filepath for the fourth image prompt."
},
"outputs": {
"output_image": "The final generated image or step preview.",
"status": "Current status of the image generation process (e.g., Running, Success, Failed)."
},
"description": "This Gradio app takes a text prompt and up to four image prompts to generate an AI-powered image. It provides real-time updates on the job's progress and displays the final image once complete."
}
"""
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import json
import base64
import gradio as gr
from PIL import Image
from io import BytesIO
import os
# Base URL of the generation API server. Set the IMAGE_API_HOST environment
# variable to point the app at a different server. A trailing slash is
# stripped because endpoint paths are appended as f"{host}/...".
host = os.environ.get("IMAGE_API_HOST", "http://18.119.36.46:8888").rstrip("/")
def _encode_image_prompt(image_path):
    """Return one ``image_prompts`` request entry for the image file at *image_path*.

    The file is read as raw bytes and base64-encoded into ``cn_img``; the
    other fields are the fixed values this app always sends.
    """
    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return {
        "cn_img": encoded,
        "cn_stop": 1,
        "cn_weight": 1,
        "cn_type": "ImagePrompt",
    }


def image_prompt(prompt, image1, image2, image3, image4):
    """Submit an asynchronous image-generation job and return its job ID.

    Sends the text prompt plus every provided image prompt to
    ``{host}/v2/generation/text-to-image-with-ip`` with ``async_process``
    enabled, so the server answers with a job ID that is polled later.

    Args:
        prompt: Text prompt to generate the image from.
        image1, image2, image3, image4: Filepaths of the image prompts.
            Empty slots (``None`` or ``""``) are skipped, so anywhere from
            zero to four images may be supplied.

    Returns:
        The ``job_id`` string taken from the server's JSON response.

    Raises:
        gr.Error: If an image file cannot be read, the HTTP request fails,
            the response is not valid JSON, or it carries no ``job_id``.
            Gradio shows the message to the user.
    """
    try:
        image_prompts = [
            _encode_image_prompt(path)
            for path in (image1, image2, image3, image4)
            if path
        ]
    except OSError as e:
        raise gr.Error(f"Failed to read an image prompt: {e}") from e

    params = {
        "prompt": prompt,
        "image_prompts": image_prompts,
        "async_process": True,
    }

    retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
    try:
        # The context manager closes the session's pooled connections.
        with requests.Session() as session:
            for prefix in ("http://", "https://"):
                session.mount(prefix, HTTPAdapter(max_retries=retries))
            # NOTE: with urllib3's default ``allowed_methods`` a POST is not
            # retried on the status codes above, only on connection errors.
            # Left that way on purpose: re-sending could queue a duplicate job.
            response = session.post(
                url=f"{host}/v2/generation/text-to-image-with-ip",
                json=params,
                timeout=15,
            )
            response.raise_for_status()  # Surface bad HTTP status codes.
            result = response.json()
    except json.JSONDecodeError as e:
        # Checked before RequestException: newer requests versions raise a
        # JSONDecodeError that is also a RequestException subclass.
        raise gr.Error("Failed to decode JSON response from the server.") from e
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"API request failed: {e}") from e

    job_id = result.get("job_id") if isinstance(result, dict) else None
    if not job_id:
        raise gr.Error("Job ID not found in response")
    return job_id
def query_status(job_id):
    """Poll the server once for the state of generation job *job_id*.

    Queries ``{host}/v1/generation/query-job`` with step previews enabled
    and converts the reply into something the UI can display.

    Args:
        job_id: Job ID returned by ``image_prompt``.

    Returns:
        An ``(image, status_text)`` tuple. ``image`` is a PIL image (the
        final result on success, the step preview while running) or
        ``None``; ``status_text`` is a human-readable status message.
        Errors are reported through ``status_text`` rather than raised, so
        the periodic polling callback never crashes.
    """
    from urllib.parse import urlsplit

    # Result URLs may point at the server's loopback address; they are
    # rewritten to the hostname of the API server this app talks to.
    api_hostname = urlsplit(host).hostname

    retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
    try:
        with requests.Session() as session:
            for prefix in ("http://", "https://"):
                session.mount(prefix, HTTPAdapter(max_retries=retries))
            response = session.get(
                f"{host}/v1/generation/query-job",
                params={"job_id": job_id, "require_step_preview": "true"},
                timeout=15,
            )
            response.raise_for_status()
            job_data = response.json()
            job_stage = job_data.get("job_stage")

            if job_stage == "SUCCESS":
                # A missing or empty job_result is reported, not crashed on.
                job_result = job_data.get("job_result") or []
                final_image_url = job_result[0].get("url") if job_result else None
                if not final_image_url:
                    return None, "Final image URL not found in the job data."
                if api_hostname:
                    final_image_url = final_image_url.replace("127.0.0.1", api_hostname)
                image_response = session.get(final_image_url, timeout=15)
                image_response.raise_for_status()
                image = Image.open(BytesIO(image_response.content))
                return image, "Job completed successfully."

            if job_stage == "RUNNING":
                step_preview_base64 = job_data.get("job_step_preview")
                if step_preview_base64:
                    image = Image.open(BytesIO(base64.b64decode(step_preview_base64)))
                    return image, "Job is running, step preview available."
                return None, "Job is running, no step preview available."

            # "ERROR" and "WAITING" are assumed stage names of the job API
            # (TODO: confirm against the server); "FAILED" is also accepted.
            if job_stage in ("FAILED", "ERROR"):
                return None, "Job failed."
            if job_stage == "WAITING":
                return None, "Job is queued, waiting to start."
            return None, f"Unknown job stage: {job_stage}."
    except json.JSONDecodeError:
        # Checked before RequestException: newer requests versions raise a
        # JSONDecodeError that is also a RequestException subclass.
        return None, "Failed to decode JSON response from the server."
    except requests.exceptions.RequestException as e:
        return None, f"API request failed: {str(e)}"
    except Exception as e:
        # Top-level boundary for the polling callback: report, don't raise.
        return None, f"An unexpected error occurred: {str(e)}"
def gradio_app():
    """Build the Gradio UI and launch it.

    Layout: a text prompt, four optional image-prompt uploads, a generate
    button, the generated image and a read-only status box. Clicking the
    button submits the job through ``image_prompt``. A callback registered
    on ``demo.load`` with ``every=2`` then calls ``query_status`` every two
    seconds to refresh the image and the status text.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Image Generator\nEnter a text prompt and upload up to 4 images to generate a unique AI-powered image.")
        prompt = gr.Textbox(label="Text Prompt", placeholder="Describe the image you want to generate")
        with gr.Row():
            image1 = gr.Image(label="Image Prompt 1", type="filepath")
            image2 = gr.Image(label="Image Prompt 2", type="filepath")
            image3 = gr.Image(label="Image Prompt 3", type="filepath")
            image4 = gr.Image(label="Image Prompt 4", type="filepath")
        generate_button = gr.Button("Generate Image")
        # Visible from the start; with visible=False no result would ever show.
        output_image = gr.Image(label="Generated Image")
        status = gr.Textbox(label="Status", value="Awaiting input...", interactive=False)
        job_id = gr.State()

        def submit_job(prompt_text, img1, img2, img3, img4):
            """Submit a job; return (value for the job-ID state, status text)."""
            result = image_prompt(prompt_text, img1, img2, img3, img4)
            # Accept a (job_id, message) pair as well as a bare job ID. This
            # routes an error message to the status box instead of storing
            # the pair as the job ID, which would then be polled.
            if isinstance(result, tuple):
                return result
            return result, "Job submitted."

        def update_image(current_job_id):
            """Poll the current job; leave the status untouched while none exists."""
            if current_job_id:
                return query_status(current_job_id)
            # gr.update() with no arguments keeps the status text as it is,
            # so "Awaiting input..." or a submit error is not overwritten.
            return None, gr.update()

        generate_button.click(
            submit_job,
            inputs=[prompt, image1, image2, image3, image4],
            outputs=[job_id, status],
        )
        demo.load(update_image, inputs=job_id, outputs=[output_image, status], every=2)
    demo.launch()
if __name__ == "__main__":
    # Only direct execution builds and serves the UI; importing this module
    # does not launch the server.
    gradio_app()