amos1088 commited on
Commit
15330e3
·
1 Parent(s): d8f1f69
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -21,6 +21,7 @@ import gradio as gr
21
  import torch
22
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
23
  import os
 
24
  import spaces
25
  from huggingface_hub import login
26
  token = os.getenv("HF_TOKEN")
@@ -62,23 +63,33 @@ pipe.init_ipadapter(
62
  nb_token=64,
63
  )
64
 
65
-
66
  @spaces.GPU
67
  def gui_generation(prompt, ref_img):
68
  """
69
  Generate images using Stable Diffusion 3.5
70
  """
71
- image = pipe(
72
- width=1024,
73
- height=1024,
74
- prompt=prompt,
75
- negative_prompt="lowres, low quality, worst quality",
76
- num_inference_steps=24,
77
- guidance_scale=5.0,
78
- generator=torch.Generator("cuda").manual_seed(42),
79
- clip_image=ref_img,
80
- ipadapter_scale=0.5,
81
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  return image
84
 
 
21
  import torch
22
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
23
  import os
24
+ import numpy as np
25
  import spaces
26
  from huggingface_hub import login
27
  token = os.getenv("HF_TOKEN")
 
63
  nb_token=64,
64
  )
65
 
 
66
  @spaces.GPU
67
  def gui_generation(prompt, ref_img):
68
  """
69
  Generate images using Stable Diffusion 3.5
70
  """
71
+ # Ensure reference image is in the correct format
72
+ if ref_img:
73
+ ref_img = ref_img.convert("RGB")
74
+ ref_img_tensor = torch.tensor(
75
+ np.array(ref_img), dtype=torch.bfloat16, device="cuda"
76
+ )
77
+ else:
78
+ raise ValueError("Reference image is required.")
79
+
80
+ # Ensure the pipeline runs with correct dtype and device
81
+ with torch.autocast("cuda", dtype=torch.bfloat16):
82
+ image = pipe(
83
+ width=1024,
84
+ height=1024,
85
+ prompt=prompt,
86
+ negative_prompt="lowres, low quality, worst quality",
87
+ num_inference_steps=24,
88
+ guidance_scale=5.0,
89
+ generator=torch.Generator("cuda").manual_seed(42),
90
+ clip_image=ref_img_tensor,
91
+ ipadapter_scale=0.5,
92
+ ).images[0]
93
 
94
  return image
95