"""Load several text-to-image diffusion pipelines on CPU and run one prompt through each."""
import logging
import time

import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from PIL import Image
from tqdm import tqdm

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Solid color of the placeholder image returned when generation fails, so
# callers always receive a PIL image and failures are visually obvious.
_FALLBACK_COLOR = "#FF0000"
# (width, height) of the placeholder when a model does not specify its own.
_DEFAULT_FALLBACK_SIZE = (512, 512)


class ModelWrapper:
    """Named wrapper around a diffusers pipeline with best-effort generation.

    Args:
        pipe: A loaded diffusers pipeline. It is called with ``prompt``,
            ``num_inference_steps``, ``guidance_scale`` and ``output_type``,
            and its result's ``.images[0]`` is returned.
        name: Label used in log messages.
        fallback_size: ``(width, height)`` of the placeholder image returned
            when generation fails. Defaults to 512x512.
    """

    def __init__(self, pipe, name, fallback_size=_DEFAULT_FALLBACK_SIZE):
        self.pipe = pipe
        self.name = name
        self.fallback_size = tuple(fallback_size)

    def generate(self, prompt, steps=20, guidance_scale=7.5):
        """Generate a single image for *prompt*.

        Args:
            prompt: Text prompt passed to the pipeline.
            steps: Number of denoising steps (``num_inference_steps``).
            guidance_scale: Classifier-free guidance scale passed to the pipeline.

        Returns:
            The first PIL image produced by the pipeline. On any exception the
            error is logged with its traceback and a solid red placeholder of
            ``self.fallback_size`` is returned instead. This lets one failing
            model not abort a multi-model run.
        """
        logger.info("Generating with %s...", self.name)
        # perf_counter is monotonic, unlike time.time(), so it suits intervals.
        start = time.perf_counter()
        try:
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                output_type="pil",
            ).images[0]
        except Exception as e:
            # Deliberate best-effort: keep the full traceback, return a placeholder.
            logger.exception("%s failed: %s", self.name, e)
            return Image.new("RGB", self.fallback_size, color=_FALLBACK_COLOR)
        logger.info("%s generated in %.2fs", self.name, time.perf_counter() - start)
        return result


# (name, hub repo id, pipeline class, extra from_pretrained kwargs, fallback size).
# The fallback size matches each model's native output resolution.
_PIPELINE_CONFIGS = (
    ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
     {"safety_checker": None}, (512, 512)),
    ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline,
     {}, (512, 512)),
    ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline,
     {}, (256, 256)),
)


def load_pipelines():
    """Load every configured pipeline onto the CPU in float32.

    Models that fail to load are logged (with traceback) and skipped, so the
    result may contain only a subset of the configured models, or none.

    Returns:
        dict mapping model name to ``ModelWrapper``, in configuration order.
    """
    models = {}
    for name, repo, pipe_cls, kwargs, fallback_size in tqdm(
        _PIPELINE_CONFIGS, desc="Loading models"
    ):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,
                **kwargs,
            )
            pipe = pipe.to("cpu")
            # Attention slicing lowers peak memory; guard in case a pipeline
            # class does not expose it, instead of discarding the loaded model.
            if hasattr(pipe, "enable_attention_slicing"):
                pipe.enable_attention_slicing()
        except Exception as e:
            logger.exception("Failed to load %s: %s", name, e)
            continue
        models[name] = ModelWrapper(pipe, name, fallback_size=fallback_size)
    return models


def generate_all(models, prompt, steps=20, guidance_scale=7.5):
    """Run *prompt* through every loaded model.

    Args:
        models: dict mapping model name to ``ModelWrapper``, as returned by
            ``load_pipelines``.
        prompt: Text prompt passed to each model.
        steps: Number of denoising steps for each model.
        guidance_scale: Guidance scale forwarded to each model.

    Returns:
        dict mapping model name to the generated (or placeholder) PIL image.

    Raises:
        ValueError: if *models* is empty.
    """
    if not models:
        raise ValueError("No models loaded")
    return {
        name: model.generate(prompt, steps, guidance_scale=guidance_scale)
        for name, model in models.items()
    }