dgoot commited on
Commit
b7a0604
·
1 Parent(s): e80303a

Fix pickling issue with ZeroGPU tuning

Browse files
Files changed (1) hide show
  1. app.py +42 -34
app.py CHANGED
@@ -13,14 +13,51 @@ models = [
13
  DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @spaces.GPU
17
- def gpu(fn):
18
- return fn()
19
 
20
 
21
  @spaces.GPU(duration=180)
22
- def gpu_3min(fn):
23
- return fn()
24
 
25
 
26
  @logger.catch(reraise=True)
@@ -38,38 +75,9 @@ def generate(
38
  # Downscale the image
39
  init_image.thumbnail((1024, 1024))
40
 
41
- def progress_callback(pipe, step_index, timestep, callback_kwargs):
42
- logger.trace(
43
- f"Callback: {dict(num_timesteps=pipe.num_timesteps, step_index=step_index, timestep=timestep)}"
44
- )
45
- progress((step_index + 1, pipe.num_timesteps))
46
- return callback_kwargs
47
-
48
- pipeline_type = (
49
- StableDiffusionInstructPix2PixPipeline
50
- if model == "timbrooks/instruct-pix2pix"
51
- else AutoPipelineForImage2Image
52
- )
53
-
54
- logger.debug(f"Loading pipeline: {dict(model=model)}")
55
- pipe = pipeline_type.from_pretrained(model)
56
-
57
- logger.debug(f"Generating image: {dict(prompt=prompt)}")
58
- additional_args = (
59
- {} if model == "timbrooks/instruct-pix2pix" else dict(strength=strength)
60
- )
61
-
62
  gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu
63
 
64
- images = gpu_runner(
65
- lambda: pipe.to("cuda")(
66
- prompt=prompt,
67
- image=init_image,
68
- callback_on_step_end=progress_callback,
69
- **additional_args,
70
- ).images
71
- )
72
- return images[0]
73
 
74
 
75
  demo = gr.Interface(
 
13
  DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
14
 
15
 
16
+ def generate_image(
17
+ model: str,
18
+ prompt: str,
19
+ init_image: Image.Image,
20
+ strength: float,
21
+ progress,
22
+ ):
23
+ pipeline_type = (
24
+ StableDiffusionInstructPix2PixPipeline
25
+ if model == "timbrooks/instruct-pix2pix"
26
+ else AutoPipelineForImage2Image
27
+ )
28
+
29
+ logger.debug(f"Loading pipeline: {dict(model=model)}")
30
+ pipe = pipeline_type.from_pretrained(model).to("cuda")
31
+
32
+ logger.debug(f"Generating image: {dict(prompt=prompt)}")
33
+ additional_args = (
34
+ {} if model == "timbrooks/instruct-pix2pix" else dict(strength=strength)
35
+ )
36
+
37
+ def progress_callback(pipe, step_index, timestep, callback_kwargs):
38
+ logger.trace(
39
+ f"Callback: {dict(num_timesteps=pipe.num_timesteps, step_index=step_index, timestep=timestep)}"
40
+ )
41
+ progress((step_index + 1, pipe.num_timesteps))
42
+ return callback_kwargs
43
+
44
+ images = pipe(
45
+ prompt=prompt,
46
+ image=init_image,
47
+ callback_on_step_end=progress_callback,
48
+ **additional_args,
49
+ ).images
50
+ return images[0]
51
+
52
+
53
  @spaces.GPU
54
+ def gpu(*args, **kwargs):
55
+ return generate_image(*args, **kwargs)
56
 
57
 
58
  @spaces.GPU(duration=180)
59
+ def gpu_3min(*args, **kwargs):
60
+ return generate_image(*args, **kwargs)
61
 
62
 
63
  @logger.catch(reraise=True)
 
75
  # Downscale the image
76
  init_image.thumbnail((1024, 1024))
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  gpu_runner = gpu_3min if model == "timbrooks/instruct-pix2pix" else gpu
79
 
80
+ return gpu_runner(model, prompt, init_image, strength, progress)
 
 
 
 
 
 
 
 
81
 
82
 
83
  demo = gr.Interface(