Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,046 Bytes
25ad706 29356cb f1ee166 92c37e9 29356cb 337146c 25ad706 0782bc0 25ad706 0782bc0 25ad706 29356cb 337146c 4a66938 337146c 25ad706 0782bc0 4a66938 25ad706 29356cb 25ad706 29356cb 4a66938 25ad706 29356cb e2d6adc 4a66938 f1ee166 29356cb 13a4c81 29356cb 337146c 29356cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
def resize_image(image, max_size=2048, resample=None):
    """Shrink ``image`` so neither side exceeds ``max_size``, keeping aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), resample)`` method (a PIL image here).
        max_size: Largest allowed width or height, in pixels.
        resample: Resampling filter handed to ``image.resize``. ``None``
            (the default) selects ``Image.LANCZOS``.

    Returns:
        ``image`` itself when it already fits, otherwise whatever
        ``image.resize`` returns for the reduced dimensions.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    # Integer arithmetic avoids float truncation landing one pixel short, and
    # the clamp keeps very elongated images (e.g. 10000x1) from collapsing to
    # a zero-length side, which ``resize`` rejects.
    if width > height:
        new_width = max_size
        new_height = max(1, max_size * height // width)
    else:
        new_height = max_size
        new_width = max(1, max_size * width // height)

    if resample is None:
        resample = Image.LANCZOS
    return image.resize((new_width, new_height), resample)
def split_image(image, chunk_size=512):
    """Cut ``image`` into a row-major grid of tiles at most ``chunk_size`` wide/tall.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (a PIL image here).
        chunk_size: Maximum tile edge length in pixels; must be positive.

    Returns:
        A list of ``(tile, x, y)`` tuples, where ``(x, y)`` is the tile's
        top-left corner in ``image``. Tiles on the right and bottom edges are
        smaller when the image size is not a multiple of ``chunk_size``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # A negative step would silently produce no tiles at all, and a zero step
    # fails inside range() with an unrelated-looking message.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    width, height = image.size
    chunks = []
    for y in range(0, height, chunk_size):
        bottom = min(y + chunk_size, height)
        for x in range(0, width, chunk_size):
            right = min(x + chunk_size, width)
            chunks.append((image.crop((x, y, right, bottom)), x, y))
    return chunks
def stitch_image(chunks, original_size):
    """Assemble positioned tiles into one new RGB image.

    Args:
        chunks: Iterable of ``(tile, x, y)`` tuples; each tile is pasted with
            its top-left corner at ``(x, y)``.
        original_size: ``(width, height)`` of the canvas to create.

    Returns:
        The new ``'RGB'`` image with every tile pasted onto it.
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        position = (left, top)
        canvas.paste(tile, position)
    return canvas
def upscale_chunk(chunk, model, processor, device):
    """Run one image tile through the super-resolution model.

    Args:
        chunk: Image tile passed to ``processor``.
        model: Callable model whose output exposes a ``reconstruction`` tensor.
        processor: Callable turning ``chunk`` into a dict of tensors when
            called with ``return_tensors="pt"``.
        device: Torch device the input tensors are moved to.

    Returns:
        A PIL image built from the model's reconstruction, as 8-bit values.
    """
    encoded = processor(chunk, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        prediction = model(**batch)
        pixels = prediction.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()

    # Move axis 0 to the end; presumably channels-first -> channels-last for
    # PIL (assumes a single-image batch, so squeeze() leaves 3 axes — confirm).
    pixels = np.moveaxis(pixels, source=0, destination=-1)

    # Scale the clamped [0, 1] values to the 0..255 byte range.
    as_bytes = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(as_bytes)
@spaces.GPU(duration=120)
def main(image, model_choice, save_as_jpg=True):
    """Upscale ``image`` 4x with the chosen Swin2SR checkpoint and save the result.

    The image is first reduced by ``resize_image``, split into tiles by
    ``split_image``, each tile is upscaled independently, and the tiles are
    stitched back together.

    Args:
        image: PIL image to upscale.
        model_choice: Key selecting the checkpoint; one of ``"Pixel Perfect"``
            or ``"PSNR Match (Recommended)"``.
        save_as_jpg: When true, save as JPEG (quality 95); otherwise as PNG.

    Returns:
        The path of the written file: ``"upscaled_image.jpg"`` or
        ``"upscaled_image.png"``.

    Raises:
        KeyError: If ``model_choice`` is not a known model name.
    """
    scale = 4  # both checkpoints below are x4 models (see their names)

    # Resize the input image
    image = resize_image(image)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    checkpoint = model_paths[model_choice]
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint).to(device)

    # Split the image into chunks and upscale each one
    upscaled_chunks = []
    for chunk, x, y in split_image(image):
        upscaled_chunk = upscale_chunk(chunk, model, processor, device)
        # The model output can be larger than scale x the tile (the previous
        # code trimmed a fixed 32 px, presumably padding added by the
        # processor — confirm). Cropping to exactly scale x the tile size
        # removes that excess for every tile size, including the smaller
        # edge tiles, so neighbouring tiles line up without gaps.
        upscaled_chunk = upscaled_chunk.crop((0, 0, chunk.width * scale, chunk.height * scale))
        upscaled_chunks.append((upscaled_chunk, x * scale, y * scale))

    # Stitch the chunks back together on a canvas of the full upscaled size;
    # shrinking it would cut real content off the right and bottom edges.
    final_size = (image.width * scale, image.height * scale)
    upscaled_image = stitch_image(upscaled_chunks, final_size)

    if save_as_jpg:
        upscaled_image.save("upscaled_image.jpg", quality=95)
        return "upscaled_image.jpg"

    upscaled_image.save("upscaled_image.png")
    return "upscaled_image.png"
def gradio_interface(image, model_choice, save_as_jpg):
    """Gradio callback: run ``main`` and report failures as text, not exceptions.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        model_choice: Model name selected in the dropdown.
        save_as_jpg: Whether the result should be saved as JPEG.

    Returns:
        ``(path, None)`` on success, or ``(None, message)`` when the upscale
        could not be performed.
    """
    # Without an upload the image is None; say so plainly instead of letting
    # the pipeline fail with an attribute error on None.
    if image is None:
        return None, "Please upload an image before upscaling."

    try:
        result = main(image, model_choice, save_as_jpg)
    except Exception as e:
        # UI boundary: show the failure in the error textbox.
        return None, str(e)
    return result, None
# Gradio UI: image upload, model selector and output-format checkbox in;
# downloadable file and an error-message textbox out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            label="Select Model",
            value="PSNR Match (Recommended)"
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. Images larger than 2048x2048 will be resized while maintaining aspect ratio. The image will be processed in 512x512 pixel chunks for efficient handling.",
)

# Start the web app.
interface.launch()