Update model_pipelines.py
Browse files — model_pipelines.py (+6 −2 lines)
model_pipelines.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import torch
|
2 |
from diffusers import StableDiffusionPipeline
|
3 |
|
|
|
4 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
5 |
|
6 |
-
def load_pipelines(
|
7 |
model_ids = {
|
8 |
"sd_v1_5": "runwayml/stable-diffusion-v1-5",
|
9 |
"openjourney_v4": "prompthero/openjourney-v4",
|
@@ -11,7 +12,10 @@ def load_pipelines(device="cuda"):
|
|
11 |
}
|
12 |
pipes = {}
|
13 |
for name, mid in model_ids.items():
|
14 |
-
pipe = StableDiffusionPipeline.from_pretrained(
|
|
|
|
|
|
|
15 |
pipe = pipe.to(device)
|
16 |
pipe.enable_attention_slicing()
|
17 |
pipes[name] = pipe
|
|
|
1 |
import torch
|
2 |
from diffusers import StableDiffusionPipeline
|
3 |
|
# Select the compute device: prefer the GPU when CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
6 |
|
7 |
+
def load_pipelines():
|
8 |
model_ids = {
|
9 |
"sd_v1_5": "runwayml/stable-diffusion-v1-5",
|
10 |
"openjourney_v4": "prompthero/openjourney-v4",
|
|
|
12 |
}
|
13 |
pipes = {}
|
14 |
for name, mid in model_ids.items():
|
15 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
16 |
+
mid,
|
17 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
18 |
+
)
|
19 |
pipe = pipe.to(device)
|
20 |
pipe.enable_attention_slicing()
|
21 |
pipes[name] = pipe
|