vilarin commited on
Commit
09a21c7
·
verified ·
1 Parent(s): 2516da3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -47,7 +47,6 @@ class ModelWrapper:
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
 
50
- @spaces.GPU()
51
  def create_generator(self, model_id, checkpoint_path):
52
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
53
  state_dict = torch.load(checkpoint_path, map_location="cuda")
@@ -149,6 +148,7 @@ class ModelWrapper:
149
 
150
  return output_image_list, f"Run successfully in {(end_time-start_time):.2f} seconds"
151
 
 
152
  def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
153
  alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
154
  beta_prod_t = 1 - alpha_prod_t
@@ -185,6 +185,7 @@ class SDXLTextEncoder(torch.nn.Module):
185
 
186
  return prompt_embeds, pooled_prompt_embeds
187
 
 
188
  def create_demo():
189
  TITLE = "# DMD2-SDXL Demo"
190
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
 
 
50
  def create_generator(self, model_id, checkpoint_path):
51
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
52
  state_dict = torch.load(checkpoint_path, map_location="cuda")
 
148
 
149
  return output_image_list, f"Run successfully in {(end_time-start_time):.2f} seconds"
150
 
151
+ @spaces.GPU()
152
  def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
153
  alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
154
  beta_prod_t = 1 - alpha_prod_t
 
185
 
186
  return prompt_embeds, pooled_prompt_embeds
187
 
188
+
189
  def create_demo():
190
  TITLE = "# DMD2-SDXL Demo"
191
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"