xiehuangbao1122 committed on
Commit 9d3dabe · verified · 1 Parent(s): 75a6197

Update model_pipelines.py

Files changed (1)
  1. model_pipelines.py +41 -21
model_pipelines.py CHANGED
@@ -1,28 +1,48 @@
+from diffusers import StableDiffusionPipeline, DiffusionPipeline
 import torch
-from diffusers import StableDiffusionPipeline
+from PIL import Image
+import time
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+torch_dtype = torch.float32  # float32 is more stable on CPU
 
 def load_pipelines():
-    model_ids = {
-        "sd_v1_5": "runwayml/stable-diffusion-v1-5",
-        "openjourney_v4": "prompthero/openjourney-v4",
-        "realistic_vision": "SG161222/Realistic_Vision_V5.1"
+    print("Loading models (CPU mode)...")
+
+    # Use smaller models and optimized settings
+    models = {
+        "sd_v1_5": StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            torch_dtype=torch_dtype,
+            safety_checker=None,
+            requires_safety_checker=False
+        ).to(device),
+
+        "openjourney_v4": DiffusionPipeline.from_pretrained(
+            "prompthero/openjourney-v4",
+            torch_dtype=torch_dtype,
+            safety_checker=None
+        ).to(device),
+
+        "ldm_256": DiffusionPipeline.from_pretrained(
+            "CompVis/ldm-text2im-large-256",
+            torch_dtype=torch_dtype
+        ).to(device)
     }
-    pipes = {}
-    for name, mid in model_ids.items():
-        pipe = StableDiffusionPipeline.from_pretrained(
-            mid,
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32
-        )
-        pipe = pipe.to(device)
-        pipe.enable_attention_slicing()
-        pipes[name] = pipe
-    return pipes
+
+    # Enable memory optimization
+    for model in models.values():
+        model.enable_attention_slicing()
+
+    return models
 
-def generate_all(pipes, prompt):
+def generate_all(pipelines, prompt, steps=20):
     results = {}
-    for name, pipe in pipes.items():
-        img = pipe(prompt, guidance_scale=7.5, num_inference_steps=30).images[0]
-        results[name] = img
-    return results
+    for name, pipe in pipelines.items():
+        print(f"Generating image with {name}...")
+        start = time.time()
+        result = pipe(prompt, num_inference_steps=steps).images[0]
+        gen_time = time.time() - start
+        print(f"{name} finished in {gen_time:.2f}s")
+        results[name] = result
+    return results
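
A minimal usage sketch for the updated module (not part of this commit): it assumes model_pipelines.py is importable and the listed checkpoints can be downloaded; the prompt string and output filenames below are illustrative placeholders.

# Hypothetical driver script; the prompt and filenames are assumptions for illustration.
from model_pipelines import load_pipelines, generate_all

if __name__ == "__main__":
    pipelines = load_pipelines()  # loads all three pipelines on CPU
    images = generate_all(pipelines, "a lighthouse at dawn, oil painting", steps=20)
    for name, image in images.items():
        image.save(f"{name}.png")  # each result is a PIL.Image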