# big_ai/model_pipelines.py
# Last change: commit 5491f48 by linimi ("Update model_pipelines.py"), 878 bytes.
import torch
from diffusers import StableDiffusionPipeline
# Target device for every pipeline: the GPU when PyTorch reports CUDA support,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_pipelines(model_ids=None):
    """Load a set of Stable Diffusion checkpoints and move each onto ``device``.

    Args:
        model_ids: Optional mapping of short name -> model id (anything accepted
            by ``StableDiffusionPipeline.from_pretrained``, e.g. a Hub repo id).
            Defaults to the three checkpoints listed below, which is the
            original behavior.

    Returns:
        dict mapping each short name to its loaded pipeline, in the same order
        as ``model_ids``.
    """
    if model_ids is None:
        model_ids = {
            # NOTE(review): the runwayml repo is reported to have been removed
            # from the Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
            # commonly used mirror — confirm this id still resolves.
            "sd_v1_5": "runwayml/stable-diffusion-v1-5",
            "openjourney_v4": "prompthero/openjourney-v4",
            "realistic_vision": "SG161222/Realistic_Vision_V5.1",
        }

    # Half precision on the GPU to save memory; fp16 is poorly supported on
    # CPU, so keep full precision there. Same for every model, so pick it once.
    dtype = torch.float16 if device == "cuda" else torch.float32

    pipes = {}
    for name, repo_id in model_ids.items():
        pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
        pipe = pipe.to(device)
        # Compute attention in slices to lower peak memory at a small speed cost.
        pipe.enable_attention_slicing()
        pipes[name] = pipe
    return pipes
def generate_all(pipes, prompt, *, guidance_scale=7.5, num_inference_steps=30,
                 seed=None):
    """Render ``prompt`` once with every pipeline in ``pipes``.

    Args:
        pipes: Mapping of name -> pipeline, e.g. the result of
            ``load_pipelines()``. Each pipeline is called as
            ``pipe(prompt, **kwargs)`` and its result's ``.images[0]`` is kept.
        prompt: Text prompt passed unchanged to every pipeline.
        guidance_scale: Guidance weight forwarded to each pipeline call
            (default 7.5, as before).
        num_inference_steps: Number of inference steps forwarded to each
            pipeline call (default 30, as before).
        seed: Optional integer seed. When given, each pipeline receives its own
            freshly seeded CPU ``torch.Generator``, so reruns are repeatable and
            every model starts from the same seed. When ``None`` (the default)
            no generator is passed, matching the original behavior.

    Returns:
        dict mapping each name in ``pipes`` to the first image produced by that
        pipeline, in the iteration order of ``pipes``.
    """
    call_kwargs = {
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    results = {}
    for name, pipe in pipes.items():
        if seed is not None:
            # Fresh generator per model: reusing one would advance its state,
            # giving later models a different random stream than earlier ones.
            call_kwargs["generator"] = torch.Generator("cpu").manual_seed(seed)
        output = pipe(prompt, **call_kwargs)
        # Keep only the first image from each pipeline's output.
        results[name] = output.images[0]
    return results