Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import os | |
import shlex | |
import subprocess | |
import sys | |
import PIL.Image | |
import torch | |
from diffusers import DPMSolverMultistepScheduler | |
# Directory holding the vendored multires_textual_inversion checkout; it is both
# the target of the optional patch and the entry added to the import path.
_VENDOR_DIR = "multires_textual_inversion"

# Only on Hugging Face Spaces (SYSTEM=spaces): feed the local "patch" file to
# `patch -p1` on stdin, run inside the vendored checkout. The exit status is not
# checked, so a failed or already-applied patch does not abort the import below.
if os.getenv("SYSTEM") == "spaces":
    patch_argv = shlex.split("patch -p1")
    with open("patch") as patch_file:
        subprocess.run(patch_argv, cwd=_VENDOR_DIR, stdin=patch_file)

# Put the vendored checkout at the front of sys.path so `pipeline` resolves to it.
sys.path.insert(0, _VENDOR_DIR)
from pipeline import MultiResPipeline, load_learned_concepts
class Model:
    """Stable Diffusion v1.5 wrapped in the multi-resolution textual-inversion pipeline.

    On construction:

    * the pipeline is placed on ``cuda:0`` when CUDA is available and on CPU otherwise;
    * on GPU it is loaded in float16 from the ``fp16`` revision;
    * its scheduler is replaced with a DPM-Solver++ multistep scheduler;
    * learned concepts are loaded from ``textual_inversion_outputs/``.
    """

    def __init__(self):
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if has_cuda else "cpu")

        model_id = "runwayml/stable-diffusion-v1-5"
        # Half-precision weights are requested only for the GPU path; the CPU
        # path loads with the library's default dtype and revision.
        if self.device.type == "cpu":
            load_kwargs = {}
        else:
            load_kwargs = {"torch_dtype": torch.float16, "revision": "fp16"}
        loaded = MultiResPipeline.from_pretrained(model_id, **load_kwargs)

        self.pipe = loaded.to(self.device)
        self.pipe.scheduler = self._make_scheduler()

        # Presumably maps placeholder tokens to learned embeddings (produced by
        # load_learned_concepts) -- TODO confirm against the pipeline module.
        self.string_to_param_dict = load_learned_concepts(self.pipe, "textual_inversion_outputs/")

    @staticmethod
    def _make_scheduler():
        """Build the DPM-Solver++ (midpoint) multistep scheduler used for sampling."""
        return DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            trained_betas=None,
            predict_epsilon=True,
            thresholding=False,
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            lower_order_final=True,
        )

    def run(self, prompt: str, n_images: int, n_steps: int, seed: int) -> list[PIL.Image.Image]:
        """Generate ``n_images`` samples of ``prompt`` with ``n_steps`` inference steps.

        A generator on the model's device is seeded with ``seed``, so the same
        arguments reproduce the same sampling noise. The pipeline's return value
        is passed straight through. The annotation says a list of PIL images;
        NOTE(review): confirm that is what MultiResPipeline returns.
        """
        prompts = [prompt] * n_images
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)
        return self.pipe(prompts, self.string_to_param_dict, num_inference_steps=n_steps, generator=rng)