# big_ai / model_pipelines.py
# Origin: Hugging Face Space upload by xiehuangbao1122
# ("Update model_pipelines.py", commit a20155e, verified)
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time
import logging
from tqdm import tqdm
# Module-wide logging: timestamped INFO-level messages on the root logger,
# plus a module-scoped logger used by the classes/functions below.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelWrapper:
    """Thin wrapper around a diffusers pipeline that adds uniform logging
    and a red-placeholder fallback image when generation fails."""

    def __init__(self, pipe, name):
        # pipe: a loaded diffusers pipeline (callable); name: label used in logs.
        self.pipe = pipe
        self.name = name

    def generate(self, prompt, steps=20, guidance_scale=7.5):
        """Run the pipeline on *prompt* and return a single PIL image.

        Args:
            prompt: text prompt forwarded to the pipeline.
            steps: number of inference steps (default 20).
            guidance_scale: classifier-free guidance strength. Default 7.5
                matches the previously hard-coded value, so existing callers
                are unaffected.

        Returns:
            The first generated PIL image, or a solid red 512x512 placeholder
            if the pipeline raised — callers always get an Image, never None.
        """
        # Lazy %-style args avoid formatting when the log level filters the call.
        logger.info("Generating with %s...", self.name)
        try:
            start = time.time()
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                output_type="pil"
            ).images[0]
            logger.info("%s generated in %.2fs", self.name, time.time() - start)
            return result
        except Exception as e:
            # Best-effort fallback: logger.exception records the full traceback
            # (logger.error here used to discard it), then return a visibly
            # wrong all-red image instead of propagating the error.
            logger.exception("%s failed: %s", self.name, e)
            return Image.new("RGB", (512, 512), color="#FF0000")
def load_pipelines(device="cpu"):
    """Load every configured diffusion pipeline, skipping failures.

    Args:
        device: torch device string each pipeline is moved to. Default "cpu"
            matches the previously hard-coded behavior, so existing callers
            are unaffected.

    Returns:
        dict mapping model name -> ModelWrapper. Models that fail to
        download/initialize are logged and skipped, so the result may hold
        fewer entries than configured (possibly none).
    """
    # (name, HF repo id, pipeline class, extra from_pretrained kwargs)
    configs = [
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         {"safety_checker": None}),
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {})
    ]
    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,
                **kwargs
            )
            pipe = pipe.to(device)
            # Attention slicing trades some speed for a smaller memory peak.
            pipe.enable_attention_slicing()
            models[name] = ModelWrapper(pipe, name)
        except Exception as e:
            # logger.exception keeps the traceback (logger.error dropped it);
            # continue so one bad model doesn't block the rest.
            logger.exception("Failed to load %s: %s", name, e)
            continue
    return models
def generate_all(models, prompt, steps=20):
    """Generate one image per loaded model for the same prompt.

    Args:
        models: mapping of model name -> ModelWrapper (as from load_pipelines).
        prompt: text prompt passed to every model.
        steps: inference steps forwarded to each model (default 20).

    Returns:
        dict of model name -> generated image, in the models' iteration order.

    Raises:
        ValueError: if *models* is empty.
    """
    if not models:
        raise ValueError("No models loaded")
    outputs = {}
    for model_name, wrapper in models.items():
        outputs[model_name] = wrapper.generate(prompt, steps)
    return outputs