File size: 1,714 Bytes
9a5dc44
 
 
85f94cd
9a5dc44
 
 
 
 
 
 
d818369
 
 
9a5dc44
d818369
9a5dc44
 
 
 
 
 
d818369
 
 
2df1835
9a5dc44
d818369
9a5dc44
 
 
 
d818369
9a5dc44
 
 
 
d818369
 
9a5dc44
 
d818369
9a5dc44
d818369
9a5dc44
d818369
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import annotations

import os
import shlex
import subprocess
import sys

import PIL.Image
import torch
from diffusers import DPMSolverMultistepScheduler

# When SYSTEM == "spaces" (presumably the Hugging Face Spaces runtime), apply the
# local "patch" file to the multires_textual_inversion checkout before importing it.
# NOTE(review): the exit status of `patch` is not checked, so a failed or
# already-applied patch does not abort start-up.
if os.getenv("SYSTEM") == "spaces":
    patch_cmd = shlex.split("patch -p1")
    with open("patch") as patch_stream:
        subprocess.run(patch_cmd, cwd="multires_textual_inversion", stdin=patch_stream)

# Put the checkout first on the import path (relative to the current working
# directory) so the `pipeline` module below resolves from it.
sys.path.insert(0, "multires_textual_inversion")

from pipeline import MultiResPipeline, load_learned_concepts


class Model:
    """Stable Diffusion v1.5 ``MultiResPipeline`` bundled with learned concepts for image generation."""

    def __init__(self):
        """Load the pipeline onto the best available device and install a DPM-Solver++ scheduler."""
        # First CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # On GPU, request float16 weights from the "fp16" revision. On CPU, no dtype or
        # revision is requested, so from_pretrained uses its defaults.
        load_options = {}
        if self.device.type != "cpu":
            load_options = {"torch_dtype": torch.float16, "revision": "fp16"}
        pipeline = MultiResPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **load_options)
        self.pipe = pipeline.to(self.device)

        # Overwrite whatever scheduler from_pretrained attached with an explicitly
        # configured multistep DPM-Solver++ (midpoint solver, lower-order final step).
        scheduler_settings = {
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "num_train_timesteps": 1000,
            "trained_betas": None,
            "predict_epsilon": True,
            "thresholding": False,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
            "lower_order_final": True,
        }
        self.pipe.scheduler = DPMSolverMultistepScheduler(**scheduler_settings)

        # Going by the attribute name, presumably maps placeholder token strings to
        # learned parameters — TODO confirm against pipeline.load_learned_concepts.
        self.string_to_param_dict = load_learned_concepts(self.pipe, "textual_inversion_outputs/")

    def run(self, prompt: str, n_images: int, n_steps: int, seed: int) -> list[PIL.Image.Image]:
        """Generate a batch of images for a single prompt.

        Args:
            prompt: Text prompt, repeated once per requested image.
            n_images: Number of copies of ``prompt`` in the batch.
            n_steps: Forwarded to the pipeline as ``num_inference_steps``.
            seed: Seed for a freshly created ``torch.Generator`` on ``self.device``.

        Returns:
            Whatever ``self.pipe`` returns (annotated as a list of PIL images).
        """
        # A new generator is created and seeded on every call.
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)
        batch = [prompt] * n_images
        return self.pipe(batch, self.string_to_param_dict, num_inference_steps=n_steps, generator=rng)