# big_ai / model_pipelines.py
# Author: xiehuangbao1122 — last commit "Update model_pipelines.py" (5007cdb, verified), 1.46 kB
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time
# Target device that load_pipelines() moves every pipeline to.
device: str = "cpu"
# Weight dtype passed to every from_pretrained() call; float32 is more stable on CPU.
torch_dtype: torch.dtype = torch.float32
def load_pipelines():
    """Load all text-to-image pipelines used by this module.

    Each pipeline is loaded with the module-level ``torch_dtype`` and moved to
    the module-level ``device``. Attention slicing is then enabled on every
    pipeline that supports it.

    Returns:
        dict mapping a short model key ("sd_v1_5", "openjourney_v4",
        "ldm_256") to its loaded pipeline object, in that order.
    """
    print("正在加载模型(CPU模式)...")
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" repo is reportedly no
    # longer available on the Hugging Face Hub; a mirror is published as
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm before deploying.
    models = {
        "sd_v1_5": StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(device),
        # Disable the safety checker exactly as for sd_v1_5. Without
        # requires_safety_checker=False, diffusers' Stable Diffusion pipeline
        # (which this checkpoint presumably resolves to) logs a warning about
        # the missing checker.
        "openjourney_v4": DiffusionPipeline.from_pretrained(
            "prompthero/openjourney-v4",
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(device),
        "ldm_256": DiffusionPipeline.from_pretrained(
            "CompVis/ldm-text2im-large-256",
            torch_dtype=torch_dtype,
        ).to(device),
    }
    # Memory optimisation: enable attention slicing on each pipeline.
    # Guarded because not every pipeline class exposes this method in every
    # diffusers version (the LDM pipeline is the likely outlier).
    for pipe in models.values():
        if hasattr(pipe, "enable_attention_slicing"):
            pipe.enable_attention_slicing()
    return models
def generate_all(pipelines, prompt, steps=20, **pipe_kwargs):
    """Run ``prompt`` through every pipeline and collect the first image of each.

    Args:
        pipelines: mapping of model name -> callable pipeline (e.g. the dict
            returned by ``load_pipelines``). Each call must return an object
            whose ``images`` attribute is an indexable sequence.
        prompt: text prompt passed positionally to every pipeline.
        steps: forwarded to each pipeline as ``num_inference_steps``.
        **pipe_kwargs: extra keyword arguments forwarded unchanged to every
            pipeline call (e.g. ``guidance_scale`` or ``generator``), provided
            every pipeline in ``pipelines`` accepts them.

    Returns:
        dict mapping each model name to the first image its pipeline produced,
        in the iteration order of ``pipelines``.
    """
    results = {}
    for name, pipe in pipelines.items():
        print(f"正在用 {name} 生成图像...")
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump if the system clock is adjusted.
        start = time.perf_counter()
        image = pipe(prompt, num_inference_steps=steps, **pipe_kwargs).images[0]
        gen_time = time.perf_counter() - start
        print(f"{name} 生成完成,耗时 {gen_time:.2f}秒")
        results[name] = image
    return results