Spaces: Running on Zero
import gc | |
import os | |
import random | |
import numpy as np | |
import json | |
import torch | |
import uuid | |
from PIL import Image, PngImagePlugin | |
from datetime import datetime | |
from dataclasses import dataclass | |
from typing import Callable, Dict, Optional, Tuple, Any, List | |
from diffusers import ( | |
DDIMScheduler, | |
DPMSolverMultistepScheduler, | |
DPMSolverSinglestepScheduler, | |
EulerAncestralDiscreteScheduler, | |
EulerDiscreteScheduler, | |
AutoencoderKL, | |
StableDiffusionXLPipeline, | |
) | |
import logging | |
def load_pipeline(
    model_name: str,
    device: torch.device,
    hf_token: Optional[str] = None,
    vae: Optional[AutoencoderKL] = None,
    torch_dtype: torch.dtype = torch.float16,
) -> Any:
    """Load the Stable Diffusion XL pipeline and move it to ``device``.

    Args:
        model_name: Path ending in ``.safetensors`` (loaded with
            ``StableDiffusionXLPipeline.from_single_file``) or anything else
            accepted by ``StableDiffusionXLPipeline.from_pretrained``.
        device: Device the pipeline is moved to after loading.
        hf_token: Optional Hugging Face access token; forwarded to the loader
            as ``token`` only when provided.
        vae: Optional VAE that replaces the checkpoint's own. When ``None``
            no ``vae`` argument is passed, so the checkpoint's VAE is loaded.
        torch_dtype: Dtype the weights are loaded in. Defaults to
            ``torch.float16``, the value that used to be hard-coded.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        Exception: Any error raised while loading or moving the pipeline is
            logged with its traceback and re-raised unchanged.
    """
    try:
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )
        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch_dtype,
            "custom_pipeline": "lpw_stable_diffusion_xl",
            "use_safetensors": True,
            "add_watermarker": False,
        }
        # diffusers does not interpret an explicit ``vae=None`` as "load the
        # checkpoint's VAE" (explicit None components are rejected or left
        # unset), so only forward a VAE that was actually supplied.
        if vae is not None:
            load_kwargs["vae"] = vae
        if hf_token is not None:
            load_kwargs["token"] = hf_token
        pipe = loader(model_name, **load_kwargs)
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error("Failed to load pipeline: %s", e, exc_info=True)
        raise
def seed_everything(seed: int) -> torch.Generator:
    """Seed every global RNG the app relies on and return a seeded generator.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, PyTorch's
    default generators (CPU and all CUDA devices), then builds a fresh CPU
    ``torch.Generator`` seeded with the same value.

    Args:
        seed: Seed value. ``np.random.seed`` requires it to lie in
            ``[0, 2**32)``.

    Returns:
        A new CPU ``torch.Generator`` whose seed is ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
def preprocess_image_dimensions(width, height, multiple: int = 8):
    """Round ``width`` and ``height`` down to the nearest multiple of ``multiple``.

    Values that are already multiples are returned unchanged. Values smaller
    than ``multiple`` become 0.

    Args:
        width: Requested width.
        height: Requested height.
        multiple: Positive step both sides are snapped to (8 by default,
            matching the previous hard-coded behaviour).

    Returns:
        ``(width, height)`` rounded down.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    # ``v % multiple`` is 0 for exact multiples, so no separate check is needed.
    return width - width % multiple, height - height % multiple
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Any]:
    """Build the diffusers scheduler registered under the UI label ``name``.

    Args:
        scheduler_config: Scheduler config passed to ``from_config``,
            typically taken from the pipeline's current scheduler.
        name: One of ``"DPM++ 2M Karras"``, ``"DPM++ SDE Karras"``,
            ``"DPM++ 2M SDE Karras"``, ``"Euler"``, ``"Euler a"``, ``"DDIM"``.

    Returns:
        A new scheduler instance (not a factory), or ``None`` when ``name`` is
        not one of the supported labels.
    """
    # label -> (scheduler class, extra keyword overrides for ``from_config``)
    scheduler_specs = {
        "DPM++ 2M Karras": (
            DPMSolverMultistepScheduler,
            {"use_karras_sigmas": True},
        ),
        "DPM++ SDE Karras": (
            DPMSolverSinglestepScheduler,
            {"use_karras_sigmas": True},
        ),
        "DPM++ 2M SDE Karras": (
            DPMSolverMultistepScheduler,
            {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"},
        ),
        "Euler": (EulerDiscreteScheduler, {}),
        "Euler a": (EulerAncestralDiscreteScheduler, {}),
        "DDIM": (DDIMScheduler, {}),
    }
    spec = scheduler_specs.get(name)
    if spec is None:
        return None
    scheduler_cls, overrides = spec
    return scheduler_cls.from_config(scheduler_config, **overrides)
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize ``samples`` to ``width`` x ``height`` via ``torch.nn.functional.interpolate``.

    ``interpolate`` expects the target size in the order of the trailing
    tensor dimensions, i.e. ``(height, width)``, hence the swap below.
    ``upscale_method`` is passed straight through as the interpolation ``mode``.
    """
    target_size = (height, width)
    resize = torch.nn.functional.interpolate
    return resize(samples, size=target_size, mode=upscale_method)
def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Scale the last two dimensions of ``samples`` by ``scale_by``.

    The target height and width are ``round(dim * scale_by)`` using Python's
    built-in ``round``. The actual resize is delegated to ``common_upscale``
    with ``upscale_method`` as the interpolation mode.
    """
    src_height = samples.shape[2]
    src_width = samples.shape[3]
    return common_upscale(
        samples,
        round(src_width * scale_by),
        round(src_height * scale_by),
        upscale_method,
    )
def free_memory() -> None:
    """Free up GPU and system memory.

    The Python garbage collector runs first so that tensors kept alive only
    by reference cycles are destroyed and their CUDA blocks go back to
    PyTorch's caching allocator. ``torch.cuda.empty_cache()`` then releases
    those now-unused cached blocks. ``torch.cuda.ipc_collect()`` reclaims
    memory released through CUDA IPC.

    Calling ``empty_cache()`` before ``gc.collect()`` (the previous order)
    left memory freed by the collector stuck in the cache.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def save_image(image, output_dir, quality: int = 80):
    """Save ``image`` as a JPEG with a random UUID4 filename in ``output_dir``.

    ``output_dir`` is created if it does not exist. JPEG has no alpha channel
    or palette, so images in any mode other than 1/L/RGB/CMYK (e.g. RGBA, LA,
    P) are converted to RGB first; Pillow would otherwise refuse to write
    them. The caller's image object is not modified (``convert`` returns a
    copy).

    Args:
        image: A Pillow image (anything with ``mode``, ``convert`` and ``save``).
        output_dir: Directory to write into.
        quality: JPEG quality passed to Pillow (80 by default, as before).

    Returns:
        Path of the written ``.jpg`` file.
    """
    jpeg_modes = ("1", "L", "RGB", "CMYK")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")
    filename = str(uuid.uuid4()) + ".jpg"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    image.save(filepath, "JPEG", quality=quality)
    return filepath