dgoot commited on
Commit
2d7f410
·
1 Parent(s): b7a0604

Cache pipeline before running on GPU

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -12,6 +12,18 @@ models = [
12
  ]
13
  DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def generate_image(
17
  model: str,
@@ -20,14 +32,8 @@ def generate_image(
20
  strength: float,
21
  progress,
22
  ):
23
- pipeline_type = (
24
- StableDiffusionInstructPix2PixPipeline
25
- if model == "timbrooks/instruct-pix2pix"
26
- else AutoPipelineForImage2Image
27
- )
28
-
29
  logger.debug(f"Loading pipeline: {dict(model=model)}")
30
- pipe = pipeline_type.from_pretrained(model).to("cuda")
31
 
32
  logger.debug(f"Generating image: {dict(prompt=prompt)}")
33
  additional_args = (
@@ -75,6 +81,13 @@ def generate(
75
  # Downscale the image
76
  init_image.thumbnail((1024, 1024))
77
 
 
 
 
 
 
 
 
78
  gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu
79
 
80
  return gpu_runner(model, prompt, init_image, strength, progress)
 
12
  ]
13
  DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
 
15
+ loaded_models: set[str] = set()
16
+
17
+
18
+ def load_pipeline(model):
19
+ pipeline_type = (
20
+ StableDiffusionInstructPix2PixPipeline
21
+ if model == "timbrooks/instruct-pix2pix"
22
+ else AutoPipelineForImage2Image
23
+ )
24
+
25
+ return pipeline_type.from_pretrained(model)
26
+
27
 
28
  def generate_image(
29
  model: str,
 
32
  strength: float,
33
  progress,
34
  ):
 
 
 
 
 
 
35
  logger.debug(f"Loading pipeline: {dict(model=model)}")
36
+ pipe = load_pipeline(model).to("cuda")
37
 
38
  logger.debug(f"Generating image: {dict(prompt=prompt)}")
39
  additional_args = (
 
81
  # Downscale the image
82
  init_image.thumbnail((1024, 1024))
83
 
84
+ # Cache the model files for the pipeline
85
+ if model not in loaded_models:
86
+ logger.debug(f"Caching pipeline: {dict(model=model)}")
87
+
88
+ load_pipeline(model)
89
+ loaded_models.add(model)
90
+
91
  gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu
92
 
93
  return gpu_runner(model, prompt, init_image, strength, progress)