Update model_pipelines.py
Browse files- model_pipelines.py +54 -38
model_pipelines.py
CHANGED
@@ -2,47 +2,63 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
import time
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
models = {
|
14 |
-
"sd_v1_5": StableDiffusionPipeline.from_pretrained(
|
15 |
-
"runwayml/stable-diffusion-v1-5",
|
16 |
-
torch_dtype=torch_dtype,
|
17 |
-
safety_checker=None,
|
18 |
-
requires_safety_checker=False
|
19 |
-
).to(device),
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
|
34 |
-
for
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
return models
|
38 |
|
39 |
-
def generate_all(
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
results[name] = result
|
48 |
-
return results
|
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
import time
|
5 |
+
import logging
|
6 |
+
from tqdm import tqdm
|
7 |
|
8 |
+
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
|
11 |
+
class ModelWrapper:
    """Pair a loaded diffusers pipeline with a display name.

    Gives every model a uniform, failure-tolerant ``generate`` interface so
    batch code can iterate over heterogeneous pipelines without per-model
    error handling.
    """

    def __init__(self, pipe, name):
        # pipe: a loaded diffusers pipeline — called directly in generate()
        #       and expected to return an object with an .images list.
        # name: human-readable identifier, used only in log messages.
        self.pipe = pipe
        self.name = name

    def generate(self, prompt, steps=20):
        """Run the pipeline on *prompt* and return a single PIL image.

        Args:
            prompt: text prompt forwarded to the pipeline.
            steps: number of inference steps (default 20).

        Returns:
            The first generated image; on any pipeline failure, a solid red
            512x512 placeholder so a batch run can continue past one bad model.
        """
        # Lazy %-style args avoid f-string formatting work when INFO is off.
        logger.info("Generating with %s...", self.name)
        try:
            # perf_counter is monotonic: correct for measuring elapsed time
            # even if the wall clock is adjusted mid-run.
            start = time.perf_counter()
            result = self.pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
                output_type="pil"
            ).images[0]
            logger.info("%s generated in %.2fs", self.name,
                        time.perf_counter() - start)
            return result
        except Exception:
            # Broad catch is deliberate best-effort: one failing model must
            # not abort the batch. logger.exception records the traceback,
            # which the old error(f"...{str(e)}") message discarded.
            logger.exception("%s failed", self.name)
            return Image.new("RGB", (512, 512), color="#FF0000")
|
31 |
+
|
32 |
+
def load_pipelines(device="cpu", torch_dtype=torch.float32):
    """Load all configured text-to-image pipelines.

    Args:
        device: torch device string each pipeline is moved to
            (default ``"cpu"``, matching the previous hard-coded behavior).
        torch_dtype: dtype passed to ``from_pretrained`` (default
            ``torch.float32``, required for CPU inference).

    Returns:
        dict mapping model name -> ModelWrapper. Models that fail to load
        are logged and skipped, so the result may be smaller than the
        config list — or empty.
    """
    # (name, HF repo id, pipeline class, extra from_pretrained kwargs)
    configs = [
        ("sd_v1_5", "runwayml/stable-diffusion-v1-5", StableDiffusionPipeline,
         {"safety_checker": None}),
        ("openjourney_v4", "prompthero/openjourney-v4", DiffusionPipeline, {}),
        ("ldm_256", "CompVis/ldm-text2im-large-256", DiffusionPipeline, {}),
    ]

    models = {}
    for name, repo, pipe_cls, kwargs in tqdm(configs, desc="Loading models"):
        try:
            pipe = pipe_cls.from_pretrained(
                repo,
                torch_dtype=torch_dtype,
                **kwargs
            )
            pipe = pipe.to(device)
            # Trades a little speed for a much lower attention memory peak.
            pipe.enable_attention_slicing()
            models[name] = ModelWrapper(pipe, name)
        except Exception:
            # Best-effort: skip models that fail to download/initialize
            # rather than aborting the whole load. exception() keeps the
            # traceback that the old error(str(e)) message dropped; the
            # redundant trailing `continue` is gone.
            logger.exception("Failed to load %s", name)

    return models
|
56 |
|
57 |
+
def generate_all(models, prompt, steps=20):
    """Generate one image per loaded model for *prompt*.

    Args:
        models: mapping of model name -> wrapper exposing ``generate``.
        prompt: text prompt forwarded to every model.
        steps: inference steps per model (default 20).

    Returns:
        dict mapping model name -> generated image, in ``models`` order.

    Raises:
        ValueError: if *models* is empty.
    """
    if not models:
        raise ValueError("No models loaded")

    outputs = {}
    for model_name, wrapper in models.items():
        outputs[model_name] = wrapper.generate(prompt, steps)
    return outputs
|
|
|
|