File size: 2,046 Bytes
9d3dabe
8aca2a0
9d3dabe
 
a20155e
 
8aca2a0
a20155e
 
b1f3d37
a20155e
 
 
 
9d3dabe
a20155e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3dabe
a20155e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3dabe
8aca2a0
a20155e
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import time
import logging
from tqdm import tqdm

# Module-wide logging: timestamped INFO-level messages on the root handler.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelWrapper:
    """Thin wrapper around a diffusers pipeline adding timing logs and a
    fail-safe generate() that never raises."""

    def __init__(self, pipe, name):
        # pipe: a callable diffusers pipeline; name: label used in log lines.
        self.pipe = pipe
        self.name = name

    def generate(self, prompt, steps=20):
        """Run the wrapped pipeline on *prompt* and return a PIL image.

        On any pipeline failure the full traceback is logged and a solid-red
        512x512 placeholder image is returned instead of raising, so a batch
        run over several models keeps going.
        """
        logger.info("Generating with %s...", self.name)
        try:
            start = time.time()
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                output_type="pil"
            ).images[0]
            logger.info("%s generated in %.2fs", self.name, time.time() - start)
            return result
        except Exception:
            # logger.exception records the traceback, which the previous
            # logger.error(f"... {str(e)}") discarded.
            logger.exception("%s failed", self.name)
            return Image.new("RGB", (512, 512), color="#FF0000")

def load_pipelines():
    """Load the configured text-to-image pipelines onto CPU.

    Returns a dict mapping short model names to ModelWrapper instances.
    Models that fail to load are logged with a full traceback and skipped,
    so the result may contain fewer entries than configured (possibly none).
    """
    configs = [
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         {"safety_checker": None}),
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {})
    ]

    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch.float32,  # float32: CPU inference, fp16 is GPU-oriented
                **kwargs
            )
            pipe = pipe.to("cpu")
            # Trades a little speed for lower peak memory during attention.
            pipe.enable_attention_slicing()
            models[name] = ModelWrapper(pipe, name)
        except Exception:
            # Best-effort loading: record the traceback (logger.error(str(e))
            # lost it) and move on to the next model.
            logger.exception("Failed to load %s", name)

    return models

def generate_all(models, prompt, steps=20):
    """Generate one image per loaded model for *prompt*.

    Raises ValueError when *models* is empty; otherwise returns a dict
    mapping each model name to whatever that model's generate() yields.
    """
    if not models:
        raise ValueError("No models loaded")

    outputs = {}
    for model_name, wrapper in models.items():
        outputs[model_name] = wrapper.generate(prompt, steps)
    return outputs