import logging
import tempfile

import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces

logger = logging.getLogger(__name__)

# Dropdown display name -> Hugging Face Hub checkpoint id.
# The first entry is the one listed first in the UI.
MODEL_PATHS = {
    "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
    "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
}

# Used only if the model config does not expose an ``upscale`` attribute;
# both checkpoints above are x4 models.
DEFAULT_SCALE = 4


def resize_image(image, max_size=2048):
    """Downscale *image* so that neither side exceeds *max_size*.

    The aspect ratio is preserved. Images already within the limit are
    returned unchanged.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    aspect_ratio = width / height
    if width > height:
        new_width = max_size
        # max(1, ...) guards against a zero-sized side for extreme aspect ratios.
        new_height = max(1, int(new_width / aspect_ratio))
    else:
        new_height = max_size
        new_width = max(1, int(new_height * aspect_ratio))
    return image.resize((new_width, new_height), Image.LANCZOS)


def split_image(image, chunk_size=512):
    """Split *image* into non-overlapping tiles of at most chunk_size x chunk_size.

    Returns a list of ``(tile, x, y)`` where ``(x, y)`` is the tile's top-left
    corner in *image*. Tiles on the right and bottom edges may be smaller.
    """
    width, height = image.size
    chunks = []
    for y in range(0, height, chunk_size):
        for x in range(0, width, chunk_size):
            box = (x, y, min(x + chunk_size, width), min(y + chunk_size, height))
            chunks.append((image.crop(box), x, y))
    return chunks


def stitch_image(chunks, original_size):
    """Paste ``(tile, x, y)`` entries onto a new RGB canvas of *original_size*."""
    result = Image.new('RGB', original_size)
    for img, x, y in chunks:
        result.paste(img, (x, y))
    return result


def upscale_chunk(chunk, model, processor, device):
    """Run the super-resolution model on one PIL image and return a PIL image.

    The Swin2SR image processor pads its input, so the raw model output can be
    larger than ``scale * chunk.size``. That padding is cropped off here, so the
    result is exactly ``scale`` times the size of *chunk*. Tiles therefore
    stitch together without gaps, and the non-tiled output has the right size.
    """
    inputs = processor(chunk, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # The batch holds a single image; take it as (channels, height, width).
    output = outputs.reconstruction[0]
    # Derive the scale from the padded input width vs. the output width, then
    # drop the rows/columns that correspond to the processor's padding.
    scale = output.shape[-1] // inputs["pixel_values"].shape[-1]
    output = output[:, : chunk.height * scale, : chunk.width * scale]
    output = output.cpu().float().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)  # CHW -> HWC for PIL
    output_image = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_image)


def _select_device(auto_cpu):
    """Use CUDA when available; fall back to CPU only when *auto_cpu* is set."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if auto_cpu:
        return torch.device("cpu")
    raise RuntimeError("No GPU is available and 'Auto CPU' is disabled.")


def _save_result(image, save_as_jpg):
    """Write *image* to a unique temporary file and return its path."""
    suffix = ".jpg" if save_as_jpg else ".png"
    # A unique file per request so concurrent users never overwrite each
    # other's results (the previous fixed filename in the CWD was shared).
    with tempfile.NamedTemporaryFile(
        prefix="upscaled_image_", suffix=suffix, delete=False
    ) as tmp:
        if save_as_jpg:
            image.save(tmp, format="JPEG", quality=95)
        else:
            image.save(tmp, format="PNG")
        return tmp.name


@spaces.GPU
def main(image, model_choice, save_as_jpg=True, use_tiling=True, auto_cpu=True):
    """Upscale *image* with the selected Swin2SR model and save it to disk.

    Args:
        image: Input PIL image (downscaled first if a side exceeds 2048 px).
        model_choice: A key of ``MODEL_PATHS``.
        save_as_jpg: Save as JPEG (quality 95) if True, otherwise PNG.
        use_tiling: Process the image in 512 px tiles instead of all at once.
        auto_cpu: Fall back to CPU when no GPU is available.

    Returns:
        Path of the saved upscaled image.

    Raises:
        ValueError: If no image is given or *model_choice* is unknown.
        RuntimeError: If no GPU is available and *auto_cpu* is False.
    """
    if image is None:
        raise ValueError("Please upload an image.")
    if model_choice not in MODEL_PATHS:
        raise ValueError(f"Unknown model choice: {model_choice!r}")

    # Normalise to RGB (e.g. RGBA/L/P uploads), then cap the input size.
    image = resize_image(image.convert("RGB"))
    device = _select_device(auto_cpu)

    model_path = MODEL_PATHS[model_choice]
    processor = AutoImageProcessor.from_pretrained(model_path)
    model = Swin2SRForImageSuperResolution.from_pretrained(model_path).to(device)

    if use_tiling:
        scale = getattr(model.config, "upscale", DEFAULT_SCALE)
        # Tiles are upscaled independently (no overlap), so faint seams may be
        # visible at tile borders. Each upscaled tile is exactly `scale` times
        # its source tile, so scaled coordinates tile the output exactly.
        upscaled_chunks = [
            (upscale_chunk(chunk, model, processor, device), x * scale, y * scale)
            for chunk, x, y in split_image(image)
        ]
        final_size = (image.width * scale, image.height * scale)
        upscaled_image = stitch_image(upscaled_chunks, final_size)
    else:
        upscaled_image = upscale_chunk(image, model, processor, device)

    return _save_result(upscaled_image, save_as_jpg)


def gradio_interface(image, model_choice, save_as_jpg, use_tiling, auto_cpu):
    """Gradio callback: return ``(path, None)`` on success, ``(None, message)`` on error."""
    try:
        result = main(image, model_choice, save_as_jpg, use_tiling, auto_cpu)
    except Exception as e:  # UI boundary: show the message, keep the traceback in logs
        logger.exception("Upscaling failed")
        return None, str(e)
    return result, None


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODEL_PATHS),
            label="Select Model",
            value="PSNR Match (Recommended)",
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
        gr.Checkbox(value=True, label="Use Tiling"),
        gr.Checkbox(value=True, label="Auto CPU"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True),
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. Images larger than 2048x2048 will be resized while maintaining aspect ratio. Use tiling for efficient processing of large images. Auto CPU will use CPU if GPU is not available.",
)

interface.launch()