amos1088 committed
Commit d8f1f69 · 1 Parent(s): b85795d
Files changed (1)
  1. app.py +28 -14
app.py CHANGED
@@ -27,10 +27,26 @@ token = os.getenv("HF_TOKEN")
 login(token=token)
 
 # Model and Pipeline Setup
+
 model_path = 'stabilityai/stable-diffusion-3.5-large'
 ip_adapter_path = './ip-adapter.bin'
 image_encoder_path = "google/siglip-so400m-patch14-384"
 
+transformer = SD3Transformer2DModel.from_pretrained(
+    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
+)
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    model_path, transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+pipe.init_ipadapter(
+    ip_adapter_path=ip_adapter_path,
+    image_encoder_path=image_encoder_path,
+    nb_token=64,
+)
+
+
 # Load transformer and pipeline
 transformer = SD3Transformer2DModel.from_pretrained(
     model_path, subfolder="transformer", torch_dtype=torch.bfloat16
@@ -48,22 +64,23 @@ pipe.init_ipadapter(
 
 
 @spaces.GPU
-def gui_generation(text, num_imgs, width, height):
+def gui_generation(prompt, ref_img):
     """
     Generate images using Stable Diffusion 3.5
     """
-    images = pipe(
-        prompt=text,
-        width=width,
-        height=height,
-        num_images_per_prompt=num_imgs,
+    image = pipe(
+        width=1024,
+        height=1024,
+        prompt=prompt,
         negative_prompt="lowres, low quality, worst quality",
         num_inference_steps=24,
        guidance_scale=5.0,
         generator=torch.Generator("cuda").manual_seed(42),
-    ).images
+        clip_image=ref_img,
+        ipadapter_scale=0.5,
+    ).images[0]
 
-    return images
+    return image
 
 
 # Create Gradio interface
@@ -72,19 +89,16 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
-        number_slider = gr.Slider(1, 30, value=2, step=1, label="Batch size")
 
     with gr.Row():
-        width_slider = gr.Slider(256, 1536, value=1024, step=64, label="Width")
-        height_slider = gr.Slider(256, 1536, value=1024, step=64, label="Height")
-
-    gallery = gr.Gallery(columns=[3], rows=[1], object_fit="contain", height="auto")
+        ref_img = gr.Image(type="pil", label="Upload Reference Image")
+        gallery = gr.Image(type="pil", label="Generated Image")
 
     generate_btn = gr.Button("Generate")
 
     generate_btn.click(
         fn=gui_generation,
-        inputs=[prompt_box, number_slider, width_slider, height_slider],
+        inputs=[prompt_box, ref_img],
         outputs=gallery
     )
 demo.launch()
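
For anyone who wants to exercise the new single-image path outside the Gradio UI, a minimal sketch follows. It assumes app.py's pipeline setup from this commit has already run (HF_TOKEN set, ./ip-adapter.bin present, CUDA available); the file names and the example prompt are hypothetical placeholders, not part of the commit.

    # Hedged usage sketch: call the updated gui_generation directly with a PIL
    # reference image, mirroring what the Gradio click handler now passes in.
    from PIL import Image

    ref_img = Image.open("reference.jpg")   # hypothetical reference image path
    image = gui_generation("a watercolor portrait of a corgi", ref_img)
    image.save("generated.png")             # the function now returns a single PIL image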