linimi commited on
Commit
2b419ce
·
verified ·
1 Parent(s): b1f3d37

Update model_pipelines.py

Browse files
Files changed (1) hide show
  1. model_pipelines.py +6 -2
model_pipelines.py CHANGED
@@ -1,9 +1,10 @@
1
  import torch
2
  from diffusers import StableDiffusionPipeline
3
 
 
4
  device = "cuda" if torch.cuda.is_available() else "cpu"
5
 
6
- def load_pipelines(device="cuda"):
7
  model_ids = {
8
  "sd_v1_5": "runwayml/stable-diffusion-v1-5",
9
  "openjourney_v4": "prompthero/openjourney-v4",
@@ -11,7 +12,10 @@ def load_pipelines(device="cuda"):
11
  }
12
  pipes = {}
13
  for name, mid in model_ids.items():
14
- pipe = StableDiffusionPipeline.from_pretrained(mid, torch_dtype=torch.float16)
 
 
 
15
  pipe = pipe.to(device)
16
  pipe.enable_attention_slicing()
17
  pipes[name] = pipe
 
1
  import torch
2
  from diffusers import StableDiffusionPipeline
3
 
4
+ # 自动判断是否有 CUDA,否则用 CPU
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
 
7
+ def load_pipelines():
8
  model_ids = {
9
  "sd_v1_5": "runwayml/stable-diffusion-v1-5",
10
  "openjourney_v4": "prompthero/openjourney-v4",
 
12
  }
13
  pipes = {}
14
  for name, mid in model_ids.items():
15
+ pipe = StableDiffusionPipeline.from_pretrained(
16
+ mid,
17
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
18
+ )
19
  pipe = pipe.to(device)
20
  pipe.enable_attention_slicing()
21
  pipes[name] = pipe