linoyts (HF Staff) committed
Commit be24378 · verified · 1 parent: 031d0a1

Update app.py

Files changed (1): app.py (+107, -19)
app.py CHANGED
@@ -1,33 +1,106 @@
 import gradio as gr
 import numpy as np
-
 import spaces
 import torch
 import random
 from PIL import Image
-
 from kontext_pipeline import FluxKontextPipeline
 from diffusers import FluxTransformer2DModel
 from diffusers.utils import load_image
-
 from huggingface_hub import hf_hub_download
 
-
 kontext_path = hf_hub_download(repo_id="diffusers/kontext", filename="kontext.safetensors")
-
 MAX_SEED = np.iinfo(np.int32).max
-
 transformer = FluxTransformer2DModel.from_single_file(kontext_path, torch_dtype=torch.bfloat16)
 pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")
 
+def concatenate_images(images, direction="horizontal"):
+    """
+    Concatenate multiple PIL images either horizontally or vertically.
+
+    Args:
+        images: List of PIL Images
+        direction: "horizontal" or "vertical"
+
+    Returns:
+        PIL Image: Concatenated image
+    """
+    if not images:
+        return None
+
+    # Filter out None images
+    valid_images = [img for img in images if img is not None]
+
+    if not valid_images:
+        return None
+
+    if len(valid_images) == 1:
+        return valid_images[0].convert("RGB")
+
+    # Convert all images to RGB
+    valid_images = [img.convert("RGB") for img in valid_images]
+
+    if direction == "horizontal":
+        # Calculate total width and max height
+        total_width = sum(img.width for img in valid_images)
+        max_height = max(img.height for img in valid_images)
+
+        # Create new image
+        concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))
+
+        # Paste images
+        x_offset = 0
+        for img in valid_images:
+            # Center image vertically if heights differ
+            y_offset = (max_height - img.height) // 2
+            concatenated.paste(img, (x_offset, y_offset))
+            x_offset += img.width
+
+    else:  # vertical
+        # Calculate max width and total height
+        max_width = max(img.width for img in valid_images)
+        total_height = sum(img.height for img in valid_images)
+
+        # Create new image
+        concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))
+
+        # Paste images
+        y_offset = 0
+        for img in valid_images:
+            # Center image horizontally if widths differ
+            x_offset = (max_width - img.width) // 2
+            concatenated.paste(img, (x_offset, y_offset))
+            y_offset += img.height
+
+    return concatenated
+
 @spaces.GPU
-def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
+def infer(input_images, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    input_image = input_image.convert("RGB")
-    # original_width, original_height = input_image.size
+
+    # Handle input_images - it could be a single image or a list of images
+    if input_images is None:
+        raise gr.Error("Please upload at least one image.")
+
+    # If it's a single image (not a list), convert to list
+    if not isinstance(input_images, list):
+        input_images = [input_images]
+
+    # Filter out None images
+    valid_images = [img for img in input_images if img is not None]
+
+    if not valid_images:
+        raise gr.Error("Please upload at least one valid image.")
+
+    # Concatenate images horizontally
+    concatenated_image = concatenate_images(valid_images, "horizontal")
+
+    if concatenated_image is None:
+        raise gr.Error("Failed to process the input images.")
+
+    # original_width, original_height = concatenated_image.size
 
     # if original_width >= original_height:
     #     new_width = 1024
@@ -38,15 +111,17 @@ def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5
     #     new_width = int(original_width * (new_height / original_height))
     #     new_width = round(new_width / 64) * 64
 
-    #input_image_resized = input_image.resize((new_width, new_height), Image.LANCZOS)
+    #concatenated_image_resized = concatenated_image.resize((new_width, new_height), Image.LANCZOS)
+
     image = pipe(
-        image=input_image,
+        image=concatenated_image,
         prompt=prompt,
         guidance_scale=guidance_scale,
         # width=new_width,
         # height=new_height,
         generator=torch.Generator().manual_seed(seed),
     ).images[0]
+
     return image, seed, gr.update(visible=True)
 
 css="""
@@ -59,12 +134,24 @@ css="""
 with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 Kontext [dev]
+        gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image
+        Upload one or multiple images.
         """)
-
         with gr.Row():
             with gr.Column():
-                input_image = gr.Image(label="Upload the image for editing", type="pil")
+                input_images = gr.Gallery(
+                    label="Upload image(s) for editing",
+                    show_label=True,
+                    elem_id="gallery_input",
+                    columns=3,
+                    rows=2,
+                    object_fit="contain",
+                    height="auto",
+                    type="pil"
+                )
+
+
+
                 with gr.Row():
                     prompt = gr.Text(
                         label="Prompt",
@@ -74,6 +161,7 @@ with gr.Blocks(css=css) as demo:
                         container=False,
                     )
                     run_button = gr.Button("Run", scale=0)
+
                 with gr.Accordion("Advanced Settings", open=False):
 
                     seed = gr.Slider(
@@ -99,17 +187,17 @@ with gr.Blocks(css=css) as demo:
             reuse_button = gr.Button("Reuse this image", visible=False)
 
 
-
     gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
-       inputs = [input_image, prompt, seed, randomize_seed, guidance_scale],
+       inputs = [input_images, prompt, seed, randomize_seed, guidance_scale],
        outputs = [result, seed, reuse_button]
    )
+
    reuse_button.click(
-       fn = lambda image: image,
+       fn = lambda image: [image] if image is not None else [],  # Convert single image to list for gallery
       inputs = [result],
-       outputs = [input_image]
+       outputs = [input_images]
    )
 
 demo.launch()
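
For reference, a minimal standalone sketch (not part of the commit) of what the new concatenation step does before the image reaches the pipeline. The helper name hconcat, the image sizes, and the colors below are arbitrary illustrations; the body is a trimmed re-statement of the horizontal branch of concatenate_images from the diff above.

# Standalone sketch: trimmed re-statement of app.py's horizontal concatenation
# branch, run on two dummy PIL images to show the geometry the Space feeds to
# the pipeline. hconcat and the sizes/colors are illustrative, not from app.py.
from PIL import Image

def hconcat(images):
    # Sum of widths, max of heights; paste left-to-right, vertically centered
    # on a white background, mirroring concatenate_images(..., "horizontal").
    images = [img.convert("RGB") for img in images if img is not None]
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    canvas = Image.new("RGB", (total_width, max_height), (255, 255, 255))
    x = 0
    for img in images:
        canvas.paste(img, (x, (max_height - img.height) // 2))
        x += img.width
    return canvas

left = Image.new("RGB", (512, 384), (200, 30, 30))    # stand-in for upload 1
right = Image.new("RGB", (256, 512), (30, 30, 200))   # stand-in for upload 2
print(hconcat([left, right]).size)  # (768, 512): widths summed, tallest height kept

With a single upload, the committed helper simply returns that image converted to RGB, so the Gallery-based UI still behaves like the previous single-image version.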