Spaces:
Running
on
Zero
Running
on
Zero
继续完善逻辑
Browse files
app.py
CHANGED
@@ -138,9 +138,31 @@ def generate(
|
|
138 |
pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
|
139 |
|
140 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
141 |
-
|
142 |
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
except GenerationError as e:
|
145 |
logger.warning(f"Generation validation error: {str(e)}")
|
146 |
raise gr.Error(str(e))
|
|
|
138 |
pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
|
139 |
|
140 |
upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
|
|
|
141 |
|
142 |
+
latents = pipe(
|
143 |
+
prompt=prompt,
|
144 |
+
negative_prompt=negative_prompt,
|
145 |
+
width=width,
|
146 |
+
height=height,
|
147 |
+
guidance_scale=guidance_scale,
|
148 |
+
num_inference_steps=num_inference_steps,
|
149 |
+
generator=generator,
|
150 |
+
output_type="latent",
|
151 |
+
).images
|
152 |
+
|
153 |
+
upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
|
154 |
+
images = upscaler_pipe(
|
155 |
+
prompt=prompt,
|
156 |
+
negative_prompt=negative_prompt,
|
157 |
+
image=upscaled_latents,
|
158 |
+
guidance_scale=guidance_scale,
|
159 |
+
num_inference_steps=num_inference_steps,
|
160 |
+
strength=upscaler_strength,
|
161 |
+
generator=generator,
|
162 |
+
output_type="pil",
|
163 |
+
).images
|
164 |
+
return images[0]
|
165 |
+
|
166 |
except GenerationError as e:
|
167 |
logger.warning(f"Generation validation error: {str(e)}")
|
168 |
raise gr.Error(str(e))
|
utils.py
CHANGED
@@ -75,4 +75,21 @@ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
|
|
75 |
),
|
76 |
"DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
|
77 |
}
|
78 |
-
return scheduler_factory_map.get(name, lambda: None)()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
),
|
76 |
"DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
|
77 |
}
|
78 |
+
return scheduler_factory_map.get(name, lambda: None)()
|
79 |
+
|
80 |
+
def common_upscale(
|
81 |
+
samples: torch.Tensor,
|
82 |
+
width: int,
|
83 |
+
height: int,
|
84 |
+
upscale_method: str,
|
85 |
+
) -> torch.Tensor:
|
86 |
+
return torch.nn.functional.interpolate(
|
87 |
+
samples, size=(height, width), mode=upscale_method
|
88 |
+
)
|
89 |
+
|
90 |
+
def upscale(
|
91 |
+
samples: torch.Tensor, upscale_method: str, scale_by: float
|
92 |
+
) -> torch.Tensor:
|
93 |
+
width = round(samples.shape[3] * scale_by)
|
94 |
+
height = round(samples.shape[2] * scale_by)
|
95 |
+
return common_upscale(samples, width, height, upscale_method)
|