NightRaven109 committed
Commit 83686fb · verified · 1 Parent(s): 9a3a9c1

Update app.py

Files changed (1):
  app.py +27 -15
app.py CHANGED

@@ -144,14 +144,23 @@ def process_image(
         generator.manual_seed(seed)
 
         # Process input image
-        input_pil = Image.fromarray(input_image)
-        width, height = input_pil.size
+        validation_image = Image.fromarray(input_image)
+        ori_width, ori_height = validation_image.size
 
-        # Resize image
-        target_width = width * upscale_factor
-        target_height = height * upscale_factor
-        target_width = target_width - (target_width % 8)
-        target_height = target_height - (target_height % 8)
+        # Resize logic from original script
+        resize_flag = False
+        rscale = upscale_factor
+        process_size = 512  # Same as args.process_size in original
+
+        if ori_width < process_size//rscale or ori_height < process_size//rscale:
+            scale = (process_size//rscale)/min(ori_width, ori_height)
+            tmp_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))
+            validation_image = tmp_image
+            resize_flag = True
+
+        validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))
+        validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
+        width, height = validation_image.size
 
         # Move pipeline to GPU for processing
         pipeline.to(accelerator.device)
@@ -164,12 +173,12 @@ def process_image(
             False,  # tile_diffusion
             None,   # tile_diffusion_size
             None,   # tile_diffusion_stride
-            prompt,  # validation_prompt / added_prompt
-            input_pil,  # validation_image
+            prompt,
+            validation_image,
             num_inference_steps=num_inference_steps,
            generator=generator,
-            height=target_height,
-            width=target_width,
+            height=height,
+            width=width,
             guidance_scale=guidance_scale,
             negative_prompt=negative_prompt,
             conditioning_scale=conditioning_scale,
@@ -178,18 +187,21 @@ def process_image(
             use_vae_encode_condition=False
         )
 
-        generated_image = output.images[0]
-
+        image = output.images[0]
+
         # Apply color fixing if specified
         if color_fix_method != "none":
             fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
-            generated_image = fix_func(generated_image, input_pil)
+            image = fix_func(image, validation_image)
+
+        if resize_flag:
+            image = image.resize((ori_width*rscale, ori_height*rscale))
 
         # Move pipeline back to CPU
         pipeline.to("cpu")
         torch.cuda.empty_cache()
 
-        return generated_image
+        return image
 
     except Exception as e:
         print(f"Error processing image: {str(e)}")