File size: 751 Bytes
8aca2a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from diffusers import StableDiffusionPipeline
def load_pipelines(device="cuda"):
    """Load every text-to-image pipeline used for comparison onto *device*.

    Args:
        device: Target device, as a string (e.g. ``"cuda"``, ``"cpu"``) or a
            ``torch.device``. Defaults to ``"cuda"``.

    Returns:
        A dict mapping a short model name (``"sd_v1_5"``, ``"openjourney_v4"``,
        ``"ldm_256"``) to its loaded pipeline, in that insertion order.
    """
    # Imported locally so this function is self-contained. DiffusionPipeline
    # picks the concrete pipeline class from each repository's own config,
    # which matters because "ldm_256" is a latent-diffusion checkpoint rather
    # than a Stable Diffusion one.
    from diffusers import DiffusionPipeline

    model_ids = {
        "sd_v1_5": "runwayml/stable-diffusion-v1-5",
        "openjourney_v4": "prompthero/openjourney-v4",
        "ldm_256": "CompVis/ldm-text2im-large-256",
    }

    # Half precision is intended for accelerators; fall back to float32 when
    # the caller asks for CPU so inference does not hit unsupported fp16 ops.
    on_cpu = torch.device(device).type == "cpu"
    dtype = torch.float32 if on_cpu else torch.float16

    pipes = {}
    for name, mid in model_ids.items():
        pipe = DiffusionPipeline.from_pretrained(mid, torch_dtype=dtype)
        pipe = pipe.to(device)
        # Attention slicing trades a little speed for lower peak memory.
        # Guarded because the resolved pipeline class is no longer fixed.
        enable_slicing = getattr(pipe, "enable_attention_slicing", None)
        if callable(enable_slicing):
            enable_slicing()
        pipes[name] = pipe
    return pipes
def generate_all(pipes, prompt, *, guidance_scale=7.5, num_inference_steps=30):
    """Run *prompt* through every pipeline and collect one image per model.

    Args:
        pipes: Mapping of model name to a callable pipeline. Each pipeline is
            called as ``pipe(prompt, guidance_scale=..., num_inference_steps=...)``
            and must return an object with an ``images`` sequence.
        prompt: The text prompt passed unchanged to every pipeline.
        guidance_scale: Value forwarded to each pipeline as ``guidance_scale``.
            Defaults to 7.5.
        num_inference_steps: Value forwarded to each pipeline as
            ``num_inference_steps``. Defaults to 30.

    Returns:
        A dict with the same keys (and order) as *pipes*, whose values are the
        first entry of each pipeline's ``images``. Empty if *pipes* is empty.
    """
    return {
        name: pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0]
        for name, pipe in pipes.items()
    }
|