ehristoforu commited on
Commit
83ddd58
·
1 Parent(s): 79c0f35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import requests
4
+ import time
5
+ import json
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ from io import BytesIO
10
+
11
+ class Prodia:
12
+ def __init__(self, api_key, base=None):
13
+ self.base = base or "https://api.prodia.com/v1"
14
+ self.headers = {
15
+ "X-Prodia-Key": api_key
16
+ }
17
+
18
+ def generate(self, params):
19
+ response = self._post(f"{self.base}/sdxl/generate", params)
20
+ return response.json()
21
+
22
+ def transform(self, params):
23
+ response = self._post(f"{self.base}/sd/transform", params)
24
+ return response.json()
25
+
26
+ def controlnet(self, params):
27
+ response = self._post(f"{self.base}/sd/controlnet", params)
28
+ return response.json()
29
+
30
+ def get_job(self, job_id):
31
+ response = self._get(f"{self.base}/job/{job_id}")
32
+ return response.json()
33
+
34
+ def wait(self, job):
35
+ job_result = job
36
+
37
+ while job_result['status'] not in ['succeeded', 'failed']:
38
+ time.sleep(0.25)
39
+ job_result = self.get_job(job['job'])
40
+
41
+ return job_result
42
+
43
+ def list_models(self):
44
+ response = self._get(f"{self.base}/models/list")
45
+ return response.json()
46
+
47
+ def _post(self, url, params):
48
+ headers = {
49
+ **self.headers,
50
+ "Content-Type": "application/json"
51
+ }
52
+ response = requests.post(url, headers=headers, data=json.dumps(params))
53
+
54
+ if response.status_code != 200:
55
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
56
+
57
+ return response
58
+
59
+ def _get(self, url):
60
+ response = requests.get(url, headers=self.headers)
61
+
62
+ if response.status_code != 200:
63
+ raise Exception(f"Bad Prodia Response: {response.status_code}")
64
+
65
+ return response
66
+
67
+
68
+ def image_to_base64(image_path):
69
+ # Open the image with PIL
70
+ with Image.open(image_path) as image:
71
+ # Convert the image to bytes
72
+ buffered = BytesIO()
73
+ image.save(buffered, format="PNG") # You can change format to PNG if needed
74
+
75
+ # Encode the bytes to base64
76
+ img_str = base64.b64encode(buffered.getvalue())
77
+
78
+ return img_str.decode('utf-8') # Convert bytes to string
79
+
80
+
81
+
82
+ prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
83
+
84
+ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
85
+ result = prodia_client.generate({
86
+ "prompt": prompt,
87
+ "negative_prompt": negative_prompt,
88
+ "model": model,
89
+ "steps": steps,
90
+ "sampler": sampler,
91
+ "cfg_scale": cfg_scale,
92
+ "width": width,
93
+ "height": height,
94
+ "seed": seed
95
+ })
96
+
97
+ job = prodia_client.wait(result)
98
+
99
+ return job["imageUrl"]
100
+
101
+ css = """
102
+ #generate {
103
+ height: 100%;
104
+ }
105
+ """
106
+
107
+ with gr.Blocks(css=css, theme="Base") as demo:
108
+
109
+
110
+
111
+ with gr.Row():
112
+ gr.Markdown("<h1><center>Stable Diffusion XL</center></h1>")
113
+ with gr.Tab("Playground"):
114
+ with gr.Row():
115
+ with gr.Column(scale=6, min_width=600):
116
+ prompt = gr.Textbox(label="Prompt", placeholder="beautiful cat, 8k", show_label=True, lines=2)
117
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="text, blurry, fuzziness", placeholder="text, blurry, fuzziness", show_label=True, lines=3)
118
+ with gr.Column():
119
+ text_button = gr.Button("Generate", variant='primary', elem_id="generate")
120
+
121
+ with gr.Row():
122
+
123
+
124
+
125
+ with gr.Column(scale=2):
126
+ image_output = gr.Image()
127
+
128
+ with gr.Accordion("Advanced options", open=False):
129
+ with gr.Row():
130
+ with gr.Column(scale=6):
131
+ model = gr.Dropdown(interactive=True,value="sd_xl_base_1.0.safetensors [be9edd61]", show_label=True, label="Stable Diffusion Checkpoint", choices=[
132
+ "sd_xl_base_1.0.safetensors [be9edd61]",
133
+ "dynavisionXL_0411.safetensors [c39cc051]",
134
+ "dreamshaperXL10_alpha2.safetensors [c8afe2ef]",
135
+ ])
136
+
137
+ with gr.Row():
138
+ with gr.Column(scale=1):
139
+ sampler = gr.Dropdown(value="DPM++ SDE", show_label=True, label="Sampler", choices=[
140
+ "Euler",
141
+ "Euler a",
142
+ "LMS",
143
+ "Heun",
144
+ "DPM2",
145
+ "DPM2 a",
146
+ "DPM++ 2S a",
147
+ "DPM++ 2M",
148
+ "DPM++ SDE",
149
+ "DPM fast",
150
+ "DPM adaptive",
151
+ "LMS Karras",
152
+ "DPM2 Karras",
153
+ "DPM2 a Karras",
154
+ "DPM++ 2S a Karras",
155
+ "DPM++ 2M Karras",
156
+ "DPM++ SDE Karras",
157
+ "DDIM",
158
+ "PLMS",
159
+ ])
160
+
161
+ with gr.Column(scale=1):
162
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=30, step=1)
163
+
164
+ with gr.Row():
165
+ with gr.Column(scale=1):
166
+ width = gr.Slider(label="Width", maximum=1024, value=1024, step=8)
167
+ height = gr.Slider(label="Height", maximum=1024, value=1024, step=8)
168
+
169
+ with gr.Column(scale=1):
170
+ batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
171
+ batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
172
+
173
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
174
+ seed = gr.Number(label="Seed", value=-1, info="""'-1' is random seed""")
175
+
176
+
177
+ text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output)
178
+
179
+ demo.queue(concurrency_count=1)
180
+ demo.launch(debug=False, share=False, show_error=False, show_api=False)