ktrndy commited on
Commit
50ef0b7
·
verified ·
1 Parent(s): 407939e

Create app_new.py

Browse files
Files changed (1) hide show
  1. app_new.py +199 -0
app_new.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+ from peft import PeftModel, LoraConfig
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
11
+
12
+ if torch.cuda.is_available():
13
+ torch_dtype = torch.float16
14
+ else:
15
+ torch_dtype = torch.float32
16
+
17
+ MAX_SEED = np.iinfo(np.int32).max
18
+ MAX_IMAGE_SIZE = 1024
19
+
20
+
21
+ # @spaces.GPU #[uncomment to use ZeroGPU]
22
+ def infer(
23
+ prompt,
24
+ negative_prompt,
25
+ width=512,
26
+ height=512,
27
+ model_id=model_id_default,
28
+ seed=42,
29
+ guidance_scale=7.0,
30
+ lora_scale=1.0,
31
+ num_inference_steps=20,
32
+ progress=gr.Progress(track_tqdm=True),
33
+ ):
34
+ generator = torch.Generator(device).manual_seed(seed)
35
+
36
+ ckpt_dir='./model_output'
37
+ unet_sub_dir = os.path.join(ckpt_dir, "unet")
38
+ text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
39
+
40
+ if model_id is None:
41
+ raise ValueError("Please specify the base model name or path")
42
+
43
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,
44
+ torch_dtype=torch_dtype,
45
+ safety_checker=None).to(device)
46
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
47
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
48
+
49
+ pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
50
+ pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
51
+
52
+ if torch_dtype in (torch.float16, torch.bfloat16):
53
+ pipe.unet.half()
54
+ pipe.text_encoder.half()
55
+
56
+ pipe.to(device)
57
+
58
+ image = pipe(
59
+ prompt=prompt,
60
+ negative_prompt=negative_prompt,
61
+ guidance_scale=guidance_scale,
62
+ num_inference_steps=num_inference_steps,
63
+ width=width,
64
+ height=height,
65
+ generator=generator,
66
+ ).images[0]
67
+
68
+ return image
69
+
70
+ css = """
71
+ #col-container {
72
+ margin: 0 auto;
73
+ max-width: 640px;
74
+ }
75
+ """
76
+
77
+ def controlnet_params(show_extra):
78
+ return gr.update(visible=show_extra)
79
+
80
+ with gr.Blocks(css=css, fill_height=True) as demo:
81
+ with gr.Column(elem_id="col-container"):
82
+ gr.Markdown(" # Text-to-Image demo")
83
+
84
+ with gr.Row():
85
+ model_id = gr.Textbox(
86
+ label="Model ID",
87
+ max_lines=1,
88
+ placeholder="Enter model id",
89
+ value=model_id_default,
90
+ )
91
+
92
+ prompt = gr.Textbox(
93
+ label="Prompt",
94
+ max_lines=1,
95
+ placeholder="Enter your prompt",
96
+ )
97
+
98
+ negative_prompt = gr.Textbox(
99
+ label="Negative prompt",
100
+ max_lines=1,
101
+ placeholder="Enter your negative prompt",
102
+ )
103
+
104
+ with gr.Row():
105
+ seed = gr.Number(
106
+ label="Seed",
107
+ minimum=0,
108
+ maximum=MAX_SEED,
109
+ step=1,
110
+ value=42,
111
+ )
112
+
113
+ guidance_scale = gr.Slider(
114
+ label="Guidance scale",
115
+ minimum=0.0,
116
+ maximum=30.0,
117
+ step=0.1,
118
+ value=7.0, # Replace with defaults that work for your model
119
+ )
120
+ with gr.Row():
121
+ lora_scale = gr.Slider(
122
+ label="LoRA scale",
123
+ minimum=0.0,
124
+ maximum=1.0,
125
+ step=0.01,
126
+ value=1.0,
127
+ )
128
+
129
+ num_inference_steps = gr.Slider(
130
+ label="Number of inference steps",
131
+ minimum=1,
132
+ maximum=100,
133
+ step=1,
134
+ value=20, # Replace with defaults that work for your model
135
+ )
136
+ with gr.Row():
137
+ controlnet_checkbox = gr.Checkbox(
138
+ label="ControlNet",
139
+ )
140
+ with gr.Group(visible=False) as controlnet_params:
141
+ control_strength = gr.Slider(
142
+ label="ControlNet conditioning scale",
143
+ minimum=0.0,
144
+ maximum=1.0,
145
+ step=0.01,
146
+ value=1.0,
147
+ )
148
+ control_mode = gr.Dropdown(
149
+ label="ControlNet mode",
150
+ choises=["edge_detection"]
151
+ )
152
+ controlnet_checkbox.change(
153
+ controlnet_params,
154
+ inputs=controlnet_checkbox,
155
+ outputs=controlnet_params
156
+ )
157
+
158
+ with gr.Accordion("Optional Settings", open=False):
159
+
160
+ with gr.Row():
161
+ width = gr.Slider(
162
+ label="Width",
163
+ minimum=256,
164
+ maximum=MAX_IMAGE_SIZE,
165
+ step=32,
166
+ value=512, # Replace with defaults that work for your model
167
+ )
168
+
169
+ height = gr.Slider(
170
+ label="Height",
171
+ minimum=256,
172
+ maximum=MAX_IMAGE_SIZE,
173
+ step=32,
174
+ value=512, # Replace with defaults that work for your model
175
+ )
176
+
177
+ run_button = gr.Button("Run", scale=0, variant="primary")
178
+ result = gr.Image(label="Result", show_label=False)
179
+
180
+ gr.on(
181
+ triggers=[run_button.click],
182
+ fn=infer,
183
+ inputs=[
184
+ prompt,
185
+ negative_prompt,
186
+ width,
187
+ height,
188
+ model_id,
189
+ seed,
190
+ guidance_scale,
191
+ lora_scale,
192
+ num_inference_steps
193
+
194
+ ],
195
+ outputs=[result],
196
+ )
197
+
198
+ if __name__ == "__main__":
199
+ demo.launch()