frogleo commited on
Commit
23063b3
·
1 Parent(s): b82c0ac

继续完善逻辑

Browse files
Files changed (2) hide show
  1. app.py +24 -2
  2. utils.py +18 -1
app.py CHANGED
@@ -138,9 +138,31 @@ def generate(
138
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
139
 
140
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
141
-
142
 
143
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  except GenerationError as e:
145
  logger.warning(f"Generation validation error: {str(e)}")
146
  raise gr.Error(str(e))
 
138
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
139
 
140
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
 
141
 
142
+ latents = pipe(
143
+ prompt=prompt,
144
+ negative_prompt=negative_prompt,
145
+ width=width,
146
+ height=height,
147
+ guidance_scale=guidance_scale,
148
+ num_inference_steps=num_inference_steps,
149
+ generator=generator,
150
+ output_type="latent",
151
+ ).images
152
+
153
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
154
+ images = upscaler_pipe(
155
+ prompt=prompt,
156
+ negative_prompt=negative_prompt,
157
+ image=upscaled_latents,
158
+ guidance_scale=guidance_scale,
159
+ num_inference_steps=num_inference_steps,
160
+ strength=upscaler_strength,
161
+ generator=generator,
162
+ output_type="pil",
163
+ ).images
164
+ return images[0]
165
+
166
  except GenerationError as e:
167
  logger.warning(f"Generation validation error: {str(e)}")
168
  raise gr.Error(str(e))
utils.py CHANGED
@@ -75,4 +75,21 @@ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
75
  ),
76
  "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
77
  }
78
- return scheduler_factory_map.get(name, lambda: None)()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  ),
76
  "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
77
  }
78
+ return scheduler_factory_map.get(name, lambda: None)()
79
+
80
+ def common_upscale(
81
+ samples: torch.Tensor,
82
+ width: int,
83
+ height: int,
84
+ upscale_method: str,
85
+ ) -> torch.Tensor:
86
+ return torch.nn.functional.interpolate(
87
+ samples, size=(height, width), mode=upscale_method
88
+ )
89
+
90
+ def upscale(
91
+ samples: torch.Tensor, upscale_method: str, scale_by: float
92
+ ) -> torch.Tensor:
93
+ width = round(samples.shape[3] * scale_by)
94
+ height = round(samples.shape[2] * scale_by)
95
+ return common_upscale(samples, width, height, upscale_method)