from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time

device = "cpu"
torch_dtype = torch.float32  # float32 is more stable on CPU

def load_pipelines():
    print("正在加载模型(CPU模式)...")
    
    # Load the models with the current diffusers API
    models = {
        "sd_v1_5": StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device),
        
        "openjourney_v4": DiffusionPipeline.from_pretrained(
            "prompthero/openjourney-v4",
            torch_dtype=torch_dtype,
            safety_checker=None
        ).to(device),
        
        "ldm_256": DiffusionPipeline.from_pretrained(
            "CompVis/ldm-text2im-large-256",
            torch_dtype=torch_dtype
        ).to(device)
    }
    
    # Memory optimization: attention slicing lowers peak memory at a small speed cost
    for model in models.values():
        model.enable_attention_slicing()
        
    return models

def generate_all(pipelines, prompt, steps=20):
    results = {}
    for name, pipe in pipelines.items():
        print(f"正在用 {name} 生成图像...")
        start = time.time()
        result = pipe(prompt, num_inference_steps=steps).images[0]
        gen_time = time.time() - start
        print(f"{name} 生成完成,耗时 {gen_time:.2f}秒")
        results[name] = result
    return results
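
# A minimal usage sketch, assuming this script is run directly on CPU; the
# prompt text and the output filenames below are placeholder assumptions.
if __name__ == "__main__":
    pipelines = load_pipelines()
    prompt = "a cozy cabin in a snowy forest, digital art"  # placeholder prompt
    images = generate_all(pipelines, prompt, steps=20)
    # The pipelines return PIL images, so each result can be saved directly.
    for name, image in images.items():
        image.save(f"{name}.png")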