linimi commited on
Commit
8aca2a0
·
verified ·
1 Parent(s): f6f74ac

Create model_pipelines.py

Browse files
Files changed (1) hide show
  1. model_pipelines.py +23 -0
model_pipelines.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+
4
+ def load_pipelines(device="cuda"):
5
+ model_ids = {
6
+ "sd_v1_5": "runwayml/stable-diffusion-v1-5",
7
+ "openjourney_v4": "prompthero/openjourney-v4",
8
+ "ldm_256": "CompVis/ldm-text2im-large-256"
9
+ }
10
+ pipes = {}
11
+ for name, mid in model_ids.items():
12
+ pipe = StableDiffusionPipeline.from_pretrained(mid, torch_dtype=torch.float16)
13
+ pipe = pipe.to(device)
14
+ pipe.enable_attention_slicing()
15
+ pipes[name] = pipe
16
+ return pipes
17
+
18
+ def generate_all(pipes, prompt):
19
+ results = {}
20
+ for name, pipe in pipes.items():
21
+ img = pipe(prompt, guidance_scale=7.5, num_inference_steps=30).images[0]
22
+ results[name] = img
23
+ return results