omer11a committed
Commit dfc5415 · 1 parent: d157117

Download model locally

Files changed (1)
  1. app.py  +13 -16
app.py CHANGED
@@ -8,13 +8,12 @@ from PIL import Image, ImageDraw
 from diffusers import DDIMScheduler
 from diffusers.models.attention_processor import AttnProcessor2_0
 from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
-from injection_utils import register_attention_editor_diffusers, unregister_attention_editor_diffusers
+from injection_utils import register_attention_editor_diffusers
 from bounded_attention import BoundedAttention
 from pytorch_lightning import seed_everything
 
-from functools import partial
-
-MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
+REMOTE_MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
+LOCAL_MODEL_PATH = "./model"
 RESOLUTION = 256
 MIN_SIZE = 0.01
 WHITE = 255
@@ -114,7 +113,6 @@ FOOTNOTE = """
 
 
 def inference(
-    model,
     boxes,
     prompts,
     subject_token_indices,
@@ -135,7 +133,10 @@ def inference(
        raise gr.Error("cuda is not available")
 
    device = torch.device("cuda")
-    model.to(device).half()
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+    model = StableDiffusionXLPipeline.from_pretrained(LOCAL_MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
+    model.unet.set_attn_processor(AttnProcessor2_0())
+    model.enable_sequential_cpu_offload()
 
    seed_everything(seed)
    start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
@@ -160,10 +161,7 @@
    )
 
    register_attention_editor_diffusers(model, editor)
-    images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
-    unregister_attention_editor_diffusers(model)
-    model.double().to(torch.device("cpu"))
-    return images
+    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
 
 
 @spaces.GPU(duration=300)
@@ -198,7 +196,7 @@
    prompts = [prompt.strip(".").strip(",").strip()] * batch_size
 
    images = inference(
-        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
        final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
        num_iterations, loss_threshold, num_guidance_steps, seed)
 
@@ -255,10 +253,9 @@ def clear(batch_size):
 def main():
    nltk.download("averaged_perceptron_tagger")
 
-    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler)
-    model.unet.set_attn_processor(AttnProcessor2_0())
-    model.enable_sequential_cpu_offload()
+    model = StableDiffusionXLPipeline.from_pretrained(REMOTE_MODEL_PATH, scheduler=scheduler)
+    model.save_pretrained(LOCAL_MODEL_PATH)
+    del model
 
    with gr.Blocks(
        css=CSS,
@@ -330,7 +327,7 @@ def main():
        )
 
        generate_image_button.click(
-            fn=partial(generate, model),
+            fn=generate,
            inputs=[
                prompt, subject_token_indices, filter_token_indices, num_tokens,
                init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
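
The commit keeps only a one-time weight download in main() and moves pipeline construction into the GPU-scoped inference path. Below is a minimal standalone sketch of that download-once, load-locally pattern; it uses the stock diffusers StableDiffusionXLPipeline rather than the Space's custom pipeline_stable_diffusion_xl_opt, and the local directory name and prompt are placeholders.

import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

LOCAL_MODEL_PATH = "./model"  # placeholder: any writable directory

# One-time download at startup: fetch the weights from the Hub and write them to disk.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.save_pretrained(LOCAL_MODEL_PATH)
del pipe

# Per-request load on the GPU worker: rebuild the pipeline from the local copy in fp16.
scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
    clip_sample=False, set_alpha_to_one=False,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    LOCAL_MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16,
).to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]  # placeholder prompt

Saving the weights once and reloading them from disk inside the @spaces.GPU-decorated call avoids re-downloading on every invocation; the fp16 load mirrors the torch_dtype=torch.float16 argument added to the new inference() body.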