import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
def load_pipelines(device="cuda"):
    """Load every configured text-to-image pipeline and move it to ``device``.

    Each checkpoint is loaded in half precision (``torch.float16``) through
    the generic ``DiffusionPipeline`` loader, so the concrete pipeline class
    comes from the checkpoint's own ``model_index.json`` rather than being
    forced to ``StableDiffusionPipeline``.

    Args:
        device: Target device passed to ``pipe.to(...)``. Defaults to
            ``"cuda"``.

    Returns:
        A dict mapping the short names ``"sd_v1_5"``, ``"openjourney_v4"``
        and ``"ldm_256"`` to their loaded pipelines.
    """
    model_ids = {
        "sd_v1_5": "runwayml/stable-diffusion-v1-5",
        "openjourney_v4": "prompthero/openjourney-v4",
        # Published as a latent-diffusion (LDMTextToImagePipeline) checkpoint,
        # not a Stable Diffusion one, so it must not be loaded through
        # StableDiffusionPipeline.
        "ldm_256": "CompVis/ldm-text2im-large-256",
    }
    pipes = {}
    for name, mid in model_ids.items():
        pipe = DiffusionPipeline.from_pretrained(mid, torch_dtype=torch.float16)
        pipe = pipe.to(device)
        # Attention slicing is a memory optimisation only; skip it for
        # pipeline classes / library versions that do not provide it.
        if hasattr(pipe, "enable_attention_slicing"):
            pipe.enable_attention_slicing()
        pipes[name] = pipe
    return pipes
def generate_all(pipes, prompt, *, guidance_scale=7.5, num_inference_steps=30):
    """Run ``prompt`` through every pipeline and collect one image from each.

    Args:
        pipes: Mapping of name -> callable pipeline. Each pipeline is called
            as ``pipe(prompt, guidance_scale=..., num_inference_steps=...)``
            and must return an object with an ``images`` sequence.
        prompt: Prompt passed unchanged to every pipeline.
        guidance_scale: Forwarded to each pipeline call. Defaults to ``7.5``.
        num_inference_steps: Forwarded to each pipeline call. Defaults to
            ``30``.

    Returns:
        A dict with the same keys (and order) as ``pipes``, mapping each name
        to the first image of that pipeline's output. Empty when ``pipes``
        is empty.
    """
    return {
        name: pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0]
        for name, pipe in pipes.items()
    }