File size: 2,046 Bytes
9d3dabe 8aca2a0 9d3dabe a20155e 8aca2a0 a20155e b1f3d37 a20155e 9d3dabe a20155e 9d3dabe a20155e 9d3dabe 8aca2a0 a20155e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time
import logging
from tqdm import tqdm
# Configure the root logger to emit INFO-and-above records formatted as
# "<timestamp> - <message>", then create this module's own logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)
class ModelWrapper:
    """Pair a pipeline callable with a display name and run it defensively.

    ``pipe`` is the callable that produces images and ``name`` is the label
    used in every log message emitted by this wrapper.
    """

    def __init__(self, pipe, name):
        """Store the pipeline callable and the label used when logging."""
        self.pipe = pipe
        self.name = name

    def generate(self, prompt, steps=20, *, guidance_scale=7.5):
        """Run the wrapped pipeline once and return its first image.

        Args:
            prompt: Passed to the pipeline as its first positional argument.
            steps: Forwarded to the pipeline as ``num_inference_steps``.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
                Keyword-only; the default matches the previously hard-coded
                value of 7.5.

        Returns:
            ``images[0]`` of the pipeline's output on success. If the
            pipeline raises an ``Exception``, the failure is logged with its
            traceback and a solid red 512x512 RGB placeholder image is
            returned instead.
        """
        logger.info("Generating with %s...", self.name)
        # perf_counter() is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        try:
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                output_type="pil",
            ).images[0]
        except Exception as exc:
            # Deliberate best-effort: one failing model must not abort the
            # caller, so log the full traceback and hand back a placeholder.
            logger.exception("%s failed: %s", self.name, exc)
            return Image.new("RGB", (512, 512), color="#FF0000")
        else:
            elapsed = time.perf_counter() - start
            logger.info("%s generated in %.2fs", self.name, elapsed)
            return result
def load_pipelines():
    """Load each configured pipeline and wrap it in a ``ModelWrapper``.

    Every entry in the local ``configs`` table is a tuple of
    ``(name, repo id, pipeline class, extra from_pretrained kwargs)``.
    A pipeline that fails at any loading step is logged with its traceback
    and skipped, so one bad model does not prevent the others from loading.

    Returns:
        A dict mapping each successfully loaded config name to its
        ``ModelWrapper``. The dict is empty if every load failed.
    """
    configs = [
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         {"safety_checker": None}),
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {}),
    ]
    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            # Every pipeline is loaded in float32 and kept on the CPU.
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,
                **kwargs,
            )
            pipe = pipe.to("cpu")
            pipe.enable_attention_slicing()
        except Exception as exc:
            # Best-effort loading: record the traceback and move on to the
            # next configured model.
            logger.exception("Failed to load %s: %s", name, exc)
        else:
            models[name] = ModelWrapper(pipe, name)
    return models
def generate_all(models, prompt, steps=20):
    """Generate one result per loaded model for the same prompt.

    Args:
        models: Mapping of model name to an object exposing
            ``generate(prompt, steps)``.
        prompt: Prompt forwarded unchanged to every model.
        steps: Step count forwarded unchanged to every model.

    Returns:
        A dict with the same keys (and iteration order) as ``models``, each
        mapped to that model's ``generate`` result.

    Raises:
        ValueError: If ``models`` is empty or otherwise falsy.
    """
    if not models:
        raise ValueError("No models loaded")

    outputs = {}
    for label, wrapper in models.items():
        outputs[label] = wrapper.generate(prompt, steps)
    return outputs