|
from diffusers import StableDiffusionPipeline, DiffusionPipeline |
|
import torch |
|
from PIL import Image |
|
import time |
|
import logging |
|
from tqdm import tqdm |
|
|
|
# Configure logging once at import time: timestamped messages at INFO level.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
class ModelWrapper:
    """Thin wrapper around a diffusers pipeline adding timing, logging,
    and a best-effort fallback image on failure."""

    def __init__(self, pipe, name: str):
        self.pipe = pipe  # a loaded diffusers pipeline (callable)
        self.name = name  # identifier used in log messages

    def generate(self, prompt: str, steps: int = 20):
        """Run the pipeline on *prompt* and return the first PIL image.

        On any pipeline error, logs the full traceback and returns a solid
        red 512x512 placeholder instead of raising, so batch callers
        (e.g. generate_all) keep going.
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Generating with %s...", self.name)
        try:
            # perf_counter is monotonic; time.time() can jump with clock changes.
            start = time.perf_counter()
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                output_type="pil",
            ).images[0]
        except Exception as e:
            # logger.exception records the traceback, which
            # logger.error(str(e)) previously discarded.
            logger.exception("%s failed: %s", self.name, e)
            return Image.new("RGB", (512, 512), color="#FF0000")
        logger.info("%s generated in %.2fs", self.name, time.perf_counter() - start)
        return result
|
|
|
def load_pipelines():
    """Load all configured text-to-image pipelines for CPU inference.

    Returns:
        dict[str, ModelWrapper]: name -> wrapped pipeline. Models that fail
        to download/initialize are logged and skipped, so the result may be
        empty (callers should check — see generate_all).
    """
    # (name, HF repo id, pipeline class, extra from_pretrained kwargs)
    configs = [
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         {"safety_checker": None}),  # skip the NSFW safety-checker component
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {}),
    ]

    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,  # CPU inference: full precision
                **kwargs,
            )
            pipe = pipe.to("cpu")
            # Trades a little speed for lower peak memory during attention.
            pipe.enable_attention_slicing()
            models[name] = ModelWrapper(pipe, name)
        except Exception:
            # Best-effort loading: record the full traceback and move on.
            logger.exception("Failed to load %s", name)

    return models
|
|
|
def generate_all(models, prompt, steps=20):
    """Generate one image per model for the given prompt.

    Args:
        models: mapping of model name to an object with a
            ``generate(prompt, steps)`` method (see ModelWrapper).
        prompt: text prompt passed to every model.
        steps: number of inference steps per model.

    Returns:
        dict mapping each model name to its generated image.

    Raises:
        ValueError: if *models* is empty.
    """
    if not models:
        raise ValueError("No models loaded")

    results = {}
    for model_name, wrapper in models.items():
        results[model_name] = wrapper.generate(prompt, steps)
    return results