Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,11 +26,11 @@ USE_TORCH_COMPILE = False
|
|
| 26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
| 27 |
PREVIEW_IMAGES = True
|
| 28 |
|
| 29 |
-
dtype = torch.
|
| 30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
-
prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=
|
| 33 |
-
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=
|
| 34 |
|
| 35 |
if ENABLE_CPU_OFFLOAD:
|
| 36 |
prior_pipeline.enable_model_cpu_offload()
|
|
@@ -46,6 +46,7 @@ if torch.cuda.is_available():
|
|
| 46 |
if PREVIEW_IMAGES:
|
| 47 |
previewer = Previewer()
|
| 48 |
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
|
|
|
| 49 |
def callback_prior(i, t, latents):
|
| 50 |
output = previewer(latents)
|
| 51 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
|
@@ -81,9 +82,9 @@ def generate(
|
|
| 81 |
num_images_per_prompt: int = 2,
|
| 82 |
#profile: gr.OAuthProfile | None = None,
|
| 83 |
) -> PIL.Image.Image:
|
| 84 |
-
prior_pipeline.to(
|
| 85 |
-
decoder_pipeline.to(
|
| 86 |
-
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 87 |
generator = torch.Generator().manual_seed(seed)
|
| 88 |
prior_output = prior_pipeline(
|
| 89 |
prompt=prompt,
|
|
|
|
| 26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
| 27 |
PREVIEW_IMAGES = True
|
| 28 |
|
| 29 |
+
dtype = torch.bfloat16
|
| 30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
+
prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=dtype).to(device)
|
| 33 |
+
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=dtype).to(device)
|
| 34 |
|
| 35 |
if ENABLE_CPU_OFFLOAD:
|
| 36 |
prior_pipeline.enable_model_cpu_offload()
|
|
|
|
| 46 |
if PREVIEW_IMAGES:
|
| 47 |
previewer = Previewer()
|
| 48 |
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
| 49 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 50 |
def callback_prior(i, t, latents):
|
| 51 |
output = previewer(latents)
|
| 52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
|
|
|
| 82 |
num_images_per_prompt: int = 2,
|
| 83 |
#profile: gr.OAuthProfile | None = None,
|
| 84 |
) -> PIL.Image.Image:
|
| 85 |
+
#prior_pipeline.to(device)
|
| 86 |
+
#decoder_pipeline.to(device)
|
| 87 |
+
#previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 88 |
generator = torch.Generator().manual_seed(seed)
|
| 89 |
prior_output = prior_pipeline(
|
| 90 |
prompt=prompt,
|