rayochoajr commited on
Commit
1479d8c
·
verified ·
1 Parent(s): a829665

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -139
app.py CHANGED
@@ -1,167 +1,106 @@
1
- """
2
- {
3
- "inputs": {
4
- "prompt": "A text prompt to generate the image from.",
5
- "image1": "Filepath for the first image prompt.",
6
- "image2": "Filepath for the second image prompt.",
7
- "image3": "Filepath for the third image prompt.",
8
- "image4": "Filepath for the fourth image prompt."
9
- },
10
- "outputs": {
11
- "output_image": "The final generated image or step preview.",
12
- "status": "Current status of the image generation process (e.g., Running, Success, Failed)."
13
- },
14
- "description": "This Gradio app takes a text prompt and up to four image prompts to generate an AI-powered image. It provides real-time updates on the job's progress and displays the final image once complete."
15
- }
16
- """
17
-
18
  import requests
19
  from requests.adapters import HTTPAdapter
20
  from requests.packages.urllib3.util.retry import Retry
21
  import json
22
  import base64
 
23
  import gradio as gr
24
  from PIL import Image
25
  from io import BytesIO
26
  import os
27
 
28
- # Host server configuration
29
  host = "http://18.119.36.46:8888"
30
 
31
  def image_prompt(prompt, image1, image2, image3, image4):
32
- """Handles the image generation process by sending the text and image prompts to the server."""
33
- try:
34
- with open(image1, "rb") as img1, open(image2, "rb") as img2, open(image3, "rb") as img3, open(image4, "rb") as img4:
35
- source1 = img1.read()
36
- source2 = img2.read()
37
- source3 = img3.read()
38
- source4 = img4.read()
39
-
40
- params = {
41
- "prompt": prompt,
42
- "image_prompts": [
43
- {
44
- "cn_img": base64.b64encode(source1).decode('utf-8'),
45
- "cn_stop": 1,
46
- "cn_weight": 1,
47
- "cn_type": "ImagePrompt"
48
- },{
49
- "cn_img": base64.b64encode(source2).decode('utf-8'),
50
- "cn_stop": 1,
51
- "cn_weight": 1,
52
- "cn_type": "ImagePrompt"
53
- },{
54
- "cn_img": base64.b64encode(source3).decode('utf-8'),
55
- "cn_stop": 1,
56
- "cn_weight": 1,
57
- "cn_type": "ImagePrompt"
58
- },{
59
- "cn_img": base64.b64encode(source4).decode('utf-8'),
60
- "cn_stop": 1,
61
- "cn_weight": 1,
62
- "cn_type": "ImagePrompt"
63
- }
64
- ],
65
- "async_process": True
66
- }
67
-
68
- session = requests.Session()
69
- retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
70
- session.mount('http://', HTTPAdapter(max_retries=retries))
71
-
72
- response = session.post(
73
- url=f"{host}/v2/generation/text-to-image-with-ip",
74
- data=json.dumps(params),
75
- headers={"Content-Type": "application/json"},
76
- timeout=15
77
- )
78
- response.raise_for_status() # Ensure we raise errors for bad HTTP status codes
79
- result = response.json()
80
-
81
- job_id = result.get('job_id')
82
- if not job_id:
83
- raise ValueError("Job ID not found in response")
84
-
85
- return job_id
86
-
87
- except requests.exceptions.RequestException as e:
88
- return None, f"API request failed: {str(e)}"
89
- except json.JSONDecodeError:
90
- return None, "Failed to decode JSON response from the server."
91
- except Exception as e:
92
- return None, f"An unexpected error occurred: {str(e)}"
93
-
94
- def query_status(job_id):
95
- """Polls the server to get the current status of the image generation job."""
96
- try:
97
- session = requests.Session()
98
- retries = Retry(total=5, backoff_factor=2, status_forcelist=[502, 503, 504])
99
- session.mount('http://', HTTPAdapter(max_retries=retries))
100
-
101
- query_url = f"http://18.119.36.46:8888/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
102
- response = session.get(query_url, timeout=15)
103
- response.raise_for_status()
104
- job_data = response.json()
105
-
106
- job_stage = job_data.get("job_stage")
107
-
108
- if job_stage == "SUCCESS":
109
- final_image_url = job_data.get("job_result")[0].get("url")
110
- if final_image_url:
111
- final_image_url = final_image_url.replace("127.0.0.1", "18.119.36.46")
112
- image_response = session.get(final_image_url, timeout=15)
113
- image_response.raise_for_status()
114
- image = Image.open(BytesIO(image_response.content))
115
- return image, "Job completed successfully."
116
- else:
117
- return None, "Final image URL not found in the job data."
118
- elif job_stage == "RUNNING":
119
- step_preview_base64 = job_data.get("job_step_preview")
120
- if step_preview_base64:
121
- image = Image.open(BytesIO(base64.b64decode(step_preview_base64)))
122
- return image, "Job is running, step preview available."
123
- return None, "Job is running, no step preview available."
124
- elif job_stage == "FAILED":
125
- return None, "Job failed."
126
- else:
127
- return None, "Unknown job stage."
128
 
129
- except requests.exceptions.RequestException as e:
130
- return None, f"API request failed: {str(e)}"
131
- except json.JSONDecodeError:
132
- return None, "Failed to decode JSON response from the server."
133
- except Exception as e:
134
- return None, f"An unexpected error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  def gradio_app():
137
- """Defines the Gradio app layout and functionality."""
138
  with gr.Blocks() as demo:
139
- gr.Markdown("# AI Image Generator\nEnter a text prompt and upload up to 4 images to generate a unique AI-powered image.")
140
-
141
- prompt = gr.Textbox(label="Text Prompt", placeholder="Describe the image you want to generate")
142
  with gr.Row():
143
  image1 = gr.Image(label="Image Prompt 1", type="filepath")
144
  image2 = gr.Image(label="Image Prompt 2", type="filepath")
145
  image3 = gr.Image(label="Image Prompt 3", type="filepath")
146
  image4 = gr.Image(label="Image Prompt 4", type="filepath")
 
 
147
 
148
  generate_button = gr.Button("Generate Image")
 
149
 
150
- output_image = gr.Image(label="Generated Image", visible=False)
151
- status = gr.Textbox(label="Status", value="Awaiting input...", interactive=False)
152
-
153
- job_id = gr.State()
154
-
155
- def update_image(job_id):
156
- if job_id:
157
- image, status_text = query_status(job_id)
158
- return image, status_text
159
- return None, "No job ID found."
160
-
161
- generate_button.click(image_prompt, inputs=[prompt, image1, image2, image3, image4], outputs=[job_id])
162
- demo.load(update_image, inputs=job_id, outputs=[output_image, status], every=2)
163
-
164
  demo.launch()
165
 
166
  if __name__ == "__main__":
167
- gradio_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  from requests.adapters import HTTPAdapter
3
  from requests.packages.urllib3.util.retry import Retry
4
  import json
5
  import base64
6
+ import time
7
  import gradio as gr
8
  from PIL import Image
9
  from io import BytesIO
10
  import os
11
 
 
12
  host = "http://18.119.36.46:8888"
13
 
14
  def image_prompt(prompt, image1, image2, image3, image4):
15
+ source1 = open(image1, "rb").read()
16
+ source2 = open(image2, "rb").read()
17
+ source3 = open(image3, "rb").read()
18
+ source4 = open(image4, "rb").read()
19
+
20
+ params = {
21
+ "prompt": prompt,
22
+ "image_prompts": [
23
+ {
24
+ "cn_img": base64.b64encode(source1).decode('utf-8'),
25
+ "cn_stop": 1,
26
+ "cn_weight": 1,
27
+ "cn_type": "ImagePrompt"
28
+ },{
29
+ "cn_img": base64.b64encode(source2).decode('utf-8'),
30
+ "cn_stop": 1,
31
+ "cn_weight": 1,
32
+ "cn_type": "ImagePrompt"
33
+ },{
34
+ "cn_img": base64.b64encode(source3).decode('utf-8'),
35
+ "cn_stop": 1,
36
+ "cn_weight": 1,
37
+ "cn_type": "ImagePrompt"
38
+ },{
39
+ "cn_img": base64.b64encode(source4).decode('utf-8'),
40
+ "cn_stop": 1,
41
+ "cn_weight": 1,
42
+ "cn_type": "ImagePrompt"
43
+ }
44
+ ],
45
+ "async_process": True
46
+ }
47
+
48
+ session = requests.Session()
49
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
50
+ session.mount('http://', HTTPAdapter(max_retries=retries))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ response = session.post(
53
+ url=f"{host}/v2/generation/text-to-image-with-ip",
54
+ data=json.dumps(params),
55
+ headers={"Content-Type": "application/json"},
56
+ timeout=10 # Increase timeout as needed
57
+ )
58
+ result = response.json()
59
+
60
+ job_id = result.get('job_id')
61
+ if job_id:
62
+ while True:
63
+ query_url = f"{host}/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
64
+ response = session.get(query_url, timeout=10) # Increase timeout as needed
65
+ job_data = response.json()
66
+
67
+ job_stage = job_data.get("job_stage")
68
+
69
+ if job_stage == "SUCCESS":
70
+ final_image_url = job_data.get("job_result")[0].get("url")
71
+ if final_image_url:
72
+ final_image_url = final_image_url.replace("127.0.0.1", "18.119.36.46")
73
+ image_response = session.get(final_image_url, timeout=10) # Increase timeout as needed
74
+ image = Image.open(BytesIO(image_response.content))
75
+ return image, "Job completed successfully."
76
+ else:
77
+ return None, "Final image URL not found in the job data."
78
+ elif job_stage == "RUNNING":
79
+ step_preview_base64 = job_data.get("job_step_preview")
80
+ if step_preview_base64:
81
+ image = Image.open(BytesIO(base64.b64decode(step_preview_base64)))
82
+ time.sleep(5)
83
+ continue
84
+ elif job_stage == "FAILED":
85
+ return None, "Job failed."
86
+ else:
87
+ return None, "Job ID not found."
88
 
89
  def gradio_app():
 
90
  with gr.Blocks() as demo:
91
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
 
 
92
  with gr.Row():
93
  image1 = gr.Image(label="Image Prompt 1", type="filepath")
94
  image2 = gr.Image(label="Image Prompt 2", type="filepath")
95
  image3 = gr.Image(label="Image Prompt 3", type="filepath")
96
  image4 = gr.Image(label="Image Prompt 4", type="filepath")
97
+ output_image = gr.Image(label="Generated Image")
98
+ status = gr.Textbox(label="Status")
99
 
100
  generate_button = gr.Button("Generate Image")
101
+ generate_button.click(image_prompt, inputs=[prompt, image1, image2, image3, image4], outputs=[output_image, status])
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  demo.launch()
104
 
105
  if __name__ == "__main__":
106
+ gradio_app()