xiehuangbao1122 committed on
Commit
a20155e
·
verified ·
1 Parent(s): 45f797c

Update model_pipelines.py

Browse files
Files changed (1) hide show
  1. model_pipelines.py +54 -38
model_pipelines.py CHANGED
@@ -2,47 +2,63 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
2
  import torch
3
  from PIL import Image
4
  import time
 
 
5
 
6
device = "cpu"
torch_dtype = torch.float32  # float32 is the stable dtype on CPU


def load_pipelines():
    """Load the three text-to-image pipelines onto the CPU and return them.

    Returns:
        dict: model key -> ready-to-use diffusers pipeline, each with
        attention slicing enabled to keep peak memory down.
    """
    print("正在加载模型(CPU模式)...")

    sd_v1_5 = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch_dtype,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)

    openjourney_v4 = DiffusionPipeline.from_pretrained(
        "prompthero/openjourney-v4",
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(device)

    ldm_256 = DiffusionPipeline.from_pretrained(
        "CompVis/ldm-text2im-large-256",
        torch_dtype=torch_dtype,
    ).to(device)

    models = {
        "sd_v1_5": sd_v1_5,
        "openjourney_v4": openjourney_v4,
        "ldm_256": ldm_256,
    }

    # Attention slicing trades a little speed for a much smaller memory peak.
    for pipe in models.values():
        pipe.enable_attention_slicing()

    return models
38
 
39
def generate_all(pipelines, prompt, steps=20):
    """Run *prompt* through every pipeline and collect one image per model.

    Args:
        pipelines: dict of name -> diffusers pipeline.
        prompt: text prompt to render.
        steps: number of inference steps per model.

    Returns:
        dict: pipeline name -> first generated image.
    """
    def _timed_generate(name, pipe):
        # Print progress and wall-clock time around a single generation.
        print(f"正在用 {name} 生成图像...")
        started = time.time()
        image = pipe(prompt, num_inference_steps=steps).images[0]
        elapsed = time.time() - started
        print(f"{name} 生成完成,耗时 {elapsed:.2f}秒")
        return image

    return {name: _timed_generate(name, pipe) for name, pipe in pipelines.items()}
 
2
  import torch
3
  from PIL import Image
4
  import time
5
+ import logging
6
+ from tqdm import tqdm
7
 
8
# Configure root logging once at import time: timestamped INFO-level output.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Module-level logger, per stdlib convention.
logger = logging.getLogger(__name__)
10
 
11
class ModelWrapper:
    """Thin wrapper around a diffusers pipeline adding logging and a safe fallback.

    Attributes:
        pipe: the underlying diffusers pipeline (already on its target device).
        name: human-readable identifier used in log messages.
    """

    def __init__(self, pipe, name):
        self.pipe = pipe
        self.name = name

    def generate(self, prompt, steps=20, guidance_scale=7.5):
        """Generate one PIL image for *prompt*.

        Args:
            prompt: text prompt passed to the pipeline.
            steps: number of denoising steps.
            guidance_scale: classifier-free guidance strength; the default 7.5
                matches the previously hard-coded value.

        Returns:
            The first generated image, or a solid red 512x512 placeholder if
            the pipeline raised — best-effort so callers always get an image.
        """
        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info("Generating with %s...", self.name)
        try:
            start = time.time()
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                output_type="pil"
            ).images[0]
            logger.info("%s generated in %.2fs", self.name, time.time() - start)
            return result
        except Exception:
            # logger.exception keeps the traceback, unlike logger.error(str(e)).
            logger.exception("%s failed", self.name)
            return Image.new("RGB", (512, 512), color="#FF0000")
31
+
32
def load_pipelines():
    """Load all text-to-image pipelines onto the CPU.

    Each successfully loaded pipeline gets attention slicing enabled and is
    wrapped in a ModelWrapper. Models that fail to load are logged and
    skipped, so the returned dict may hold fewer entries than the config
    list (possibly none).

    Returns:
        dict[str, ModelWrapper]: ready-to-use models keyed by short name.
    """
    configs = [
        # (key, HF repo id, pipeline class, extra from_pretrained kwargs)
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         # requires_safety_checker=False silences the warning diffusers emits
         # when safety_checker=None is passed on its own.
         {"safety_checker": None, "requires_safety_checker": False}),
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {})
    ]

    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,  # float32 is the stable dtype on CPU
                **kwargs
            )
            pipe = pipe.to("cpu")
            # Attention slicing lowers peak memory at a small speed cost.
            pipe.enable_attention_slicing()
            models[name] = ModelWrapper(pipe, name)
        except Exception:
            # Best-effort: skip the broken model but keep its traceback in the log.
            logger.exception("Failed to load %s", name)

    return models
56
 
57
def generate_all(models, prompt, steps=20):
    """Generate one image per loaded model.

    Args:
        models: dict of name -> ModelWrapper.
        prompt: text prompt to render.
        steps: number of inference steps per model.

    Returns:
        dict: model name -> generated image.

    Raises:
        ValueError: if *models* is empty.
    """
    if not models:
        raise ValueError("No models loaded")

    results = {}
    for model_name, wrapper in models.items():
        results[model_name] = wrapper.generate(prompt, steps)
    return results