big_ai / model_pipelines.py
linimi's picture
Create model_pipelines.py
8aca2a0 verified
raw
history blame
751 Bytes
import torch
from diffusers import StableDiffusionPipeline
def load_pipelines(device="cuda"):
    """Load every configured text-to-image pipeline and move it to ``device``.

    Args:
        device: Target device, given as a string (``"cuda"``, ``"cpu"``, ...)
            or a ``torch.device``. Defaults to ``"cuda"``.

    Returns:
        A dict mapping each short model name (``"sd_v1_5"``,
        ``"openjourney_v4"``, ``"ldm_256"``) to its loaded pipeline, in that
        order.
    """
    # Imported here so the function is self-contained. The generic loader
    # picks the concrete pipeline class from each repo's model_index.json,
    # which matters because "ldm_256" is a latent-diffusion checkpoint rather
    # than a Stable Diffusion one.
    from diffusers import DiffusionPipeline

    model_ids = {
        "sd_v1_5": "runwayml/stable-diffusion-v1-5",
        "openjourney_v4": "prompthero/openjourney-v4",
        "ldm_256": "CompVis/ldm-text2im-large-256",
    }

    # Half precision saves memory on accelerators, but CPU inference needs
    # float32 weights to run reliably.
    on_cpu = torch.device(device).type == "cpu"
    dtype = torch.float32 if on_cpu else torch.float16

    pipes = {}
    for name, mid in model_ids.items():
        pipe = DiffusionPipeline.from_pretrained(mid, torch_dtype=dtype)
        pipe = pipe.to(device)
        # Trade a little speed for a lower peak memory footprint, since
        # several pipelines are kept resident at once.
        pipe.enable_attention_slicing()
        pipes[name] = pipe
    return pipes
def generate_all(pipes, prompt, *, guidance_scale=7.5, num_inference_steps=30):
    """Run ``prompt`` through every pipeline and collect one image per model.

    Args:
        pipes: Mapping of model name to a callable pipeline. Each pipeline is
            called as ``pipe(prompt, guidance_scale=..., num_inference_steps=...)``
            and must return an object with an ``images`` sequence.
        prompt: Text prompt passed unchanged to every pipeline.
        guidance_scale: Guidance scale forwarded to each pipeline
            (default 7.5).
        num_inference_steps: Number of sampling steps forwarded to each
            pipeline (default 30).

    Returns:
        A dict mapping each model name to the first image its pipeline
        produced, in the iteration order of ``pipes``.
    """
    return {
        name: pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0]
        for name, pipe in pipes.items()
    }