NightRaven109 commited on
Commit
1857471
·
verified ·
1 Parent(s): b138ae1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -33
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import gradio as gr
4
  import spaces
5
  from PIL import Image
 
6
  from diffusers import DiffusionPipeline
7
  from huggingface_hub import snapshot_download
8
  from test_ccsr_tile import load_pipeline
@@ -78,7 +79,7 @@ def initialize_models():
78
  print(f"Error initializing models: {str(e)}")
79
  return False
80
 
81
- @torch.no_grad() # Add no_grad decorator for inference
82
  @spaces.GPU
83
  def process_image(
84
  input_image,
@@ -131,11 +132,11 @@ def process_image(
131
  resize_flag = False
132
  if ori_width < args.process_size//args.upscale or ori_height < args.process_size//args.upscale:
133
  scale = (args.process_size//args.upscale)/min(ori_width, ori_height)
134
- validation_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))
135
  resize_flag = True
136
 
137
- validation_image = validation_image.resize((validation_image.size[0]*args.upscale, validation_image.size[1]*args.upscale))
138
- validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
139
  width, height = validation_image.size
140
 
141
  # Generate image
@@ -168,7 +169,36 @@ def process_image(
168
  image = fix_func(image, validation_image)
169
 
170
  if resize_flag:
171
- image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  return image
174
 
@@ -178,7 +208,6 @@ def process_image(
178
  traceback.print_exc()
179
  return None
180
 
181
-
182
  # Define default values
183
  DEFAULT_VALUES = {
184
  "prompt": "clean, texture, high-resolution, 8k",
@@ -194,15 +223,15 @@ DEFAULT_VALUES = {
194
  # Define example data
195
  EXAMPLES = [
196
  [
197
- "examples/1.png", # Input image path
198
- "clean, texture, high-resolution, 8k", # Prompt
199
- "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", # Negative prompt
200
- 3.0, # Guidance scale
201
- 1.0, # Conditioning scale
202
- 6, # Num steps
203
- 42, # Seed
204
- 4, # Upscale factor
205
- "wavelet" # Color fix method
206
  ],
207
  [
208
  "examples/22.png",
@@ -284,16 +313,6 @@ with gr.Blocks(title="Texture Super-Resolution") as demo:
284
  format="png",
285
  show_download_button=True
286
  )
287
-
288
- output_image = gr.Image(
289
- label="Generated Image",
290
- type="pil",
291
- format="png",
292
- elem_id="output_texture",
293
- streaming=False,
294
- show_download_button=True
295
- )
296
-
297
 
298
  # Add examples
299
  gr.Examples(
@@ -305,7 +324,7 @@ with gr.Blocks(title="Texture Super-Resolution") as demo:
305
  ],
306
  outputs=output_image,
307
  fn=process_image,
308
- cache_examples=True # Cache the results for faster loading
309
  )
310
 
311
  # Define submit action
@@ -343,11 +362,5 @@ with gr.Blocks(title="Texture Super-Resolution") as demo:
343
  ]
344
  )
345
 
346
- demo.config = gr.Config(
347
- png_quality=100,
348
- png_compression=0,
349
- file_types=["png"]
350
- )
351
-
352
  if __name__ == "__main__":
353
- demo.launch()
 
3
  import gradio as gr
4
  import spaces
5
  from PIL import Image
6
+ import io
7
  from diffusers import DiffusionPipeline
8
  from huggingface_hub import snapshot_download
9
  from test_ccsr_tile import load_pipeline
 
79
  print(f"Error initializing models: {str(e)}")
80
  return False
81
 
82
+ @torch.no_grad()
83
  @spaces.GPU
84
  def process_image(
85
  input_image,
 
132
  resize_flag = False
133
  if ori_width < args.process_size//args.upscale or ori_height < args.process_size//args.upscale:
134
  scale = (args.process_size//args.upscale)/min(ori_width, ori_height)
135
+ validation_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)), Image.LANCZOS)
136
  resize_flag = True
137
 
138
+ validation_image = validation_image.resize((validation_image.size[0]*args.upscale, validation_image.size[1]*args.upscale), Image.LANCZOS)
139
+ validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8), Image.LANCZOS)
140
  width, height = validation_image.size
141
 
142
  # Generate image
 
169
  image = fix_func(image, validation_image)
170
 
171
  if resize_flag:
172
+ image = image.resize((ori_width*args.upscale, ori_height*args.upscale), Image.LANCZOS)
173
+
174
+ # Ensure maximum quality output
175
+ if isinstance(image, Image.Image):
176
+ # Convert to RGB mode if not already
177
+ if image.mode != 'RGB':
178
+ image = image.convert('RGB')
179
+
180
+ # Create a new image with white background
181
+ bg = Image.new('RGB', image.size, (255, 255, 255))
182
+ if len(image.split()) > 3: # If image has alpha channel
183
+ bg.paste(image, mask=image.split()[3])
184
+ else:
185
+ bg.paste(image)
186
+
187
+ # Optional: Apply subtle sharpening for better details
188
+ from PIL import ImageEnhance
189
+ enhancer = ImageEnhance.Sharpness(bg)
190
+ image = enhancer.enhance(1.1) # Slight sharpening
191
+
192
+ # Save with maximum quality settings
193
+ output_buffer = io.BytesIO()
194
+ image.save(
195
+ output_buffer,
196
+ format='PNG',
197
+ optimize=False,
198
+ quality=100
199
+ )
200
+ output_buffer.seek(0)
201
+ image = Image.open(output_buffer)
202
 
203
  return image
204
 
 
208
  traceback.print_exc()
209
  return None
210
 
 
211
  # Define default values
212
  DEFAULT_VALUES = {
213
  "prompt": "clean, texture, high-resolution, 8k",
 
223
  # Define example data
224
  EXAMPLES = [
225
  [
226
+ "examples/1.png",
227
+ "clean, texture, high-resolution, 8k",
228
+ "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
229
+ 3.0,
230
+ 1.0,
231
+ 6,
232
+ 42,
233
+ 4,
234
+ "wavelet"
235
  ],
236
  [
237
  "examples/22.png",
 
313
  format="png",
314
  show_download_button=True
315
  )
 
 
 
 
 
 
 
 
 
 
316
 
317
  # Add examples
318
  gr.Examples(
 
324
  ],
325
  outputs=output_image,
326
  fn=process_image,
327
+ cache_examples=True
328
  )
329
 
330
  # Define submit action
 
362
  ]
363
  )
364
 
 
 
 
 
 
 
365
  if __name__ == "__main__":
366
+ demo.launch()