from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time

device = "cpu"
torch_dtype = torch.float32


def load_pipelines():
    print("Loading models (CPU mode)...")

    models = {
        # Stable Diffusion v1.5; the safety checker is dropped to cut memory use.
        "sd_v1_5": StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch_dtype,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device),

        # Openjourney v4, a community fine-tune of Stable Diffusion.
        "openjourney_v4": DiffusionPipeline.from_pretrained(
            "prompthero/openjourney-v4",
            torch_dtype=torch_dtype,
            safety_checker=None
        ).to(device),

        # Latent Diffusion text-to-image model with 256x256 output.
        "ldm_256": DiffusionPipeline.from_pretrained(
            "CompVis/ldm-text2im-large-256",
            torch_dtype=torch_dtype
        ).to(device)
    }

    # Attention slicing trades a little speed for lower peak memory.
    # Not every pipeline class exposes it, so guard the call.
    for model in models.values():
        if hasattr(model, "enable_attention_slicing"):
            model.enable_attention_slicing()

    return models


def generate_all(pipelines, prompt, steps=20):
    # Run the same prompt through every pipeline and time each generation.
    results = {}
    for name, pipe in pipelines.items():
        print(f"Generating image with {name}...")
        start = time.time()
        result = pipe(prompt, num_inference_steps=steps).images[0]
        gen_time = time.time() - start
        print(f"{name} done in {gen_time:.2f}s")
        results[name] = result
    return results
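

# Minimal usage sketch. The prompt text and output filenames below are
# illustrative placeholders, not part of the original script; each result is a
# PIL image, so it can be written to disk with Image.save().
if __name__ == "__main__":
    pipelines = load_pipelines()
    images = generate_all(pipelines, "a watercolor lighthouse at dawn", steps=20)
    for name, image in images.items():
        image.save(f"{name}.png")
        print(f"Saved {name}.png")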