Spaces: Running on Zero
File size: 4,343 Bytes
25ad706 29356cb f1ee166 92c37e9 7129683 29356cb 337146c 25ad706 0782bc0 25ad706 0782bc0 25ad706 29356cb ce069dd 65b549e 80950c2 337146c 7129683 0782bc0 4a66938 25ad706 29356cb 9fbd930 ce069dd 9fbd930 ce069dd 9fbd930 ce069dd 25ad706 80950c2 7129683 25ad706 7129683 25ad706 7129683 29356cb 7129683 25ad706 80950c2 25ad706 29356cb e2d6adc 4a66938 f1ee166 9fbd930 29356cb 13a4c81 29356cb 7129683 29356cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
import os
def resize_image(image, max_size=2048, resample=None):
    """Downscale *image* so that neither side exceeds *max_size*, keeping aspect ratio.

    Images already within the limit are returned unchanged (the same object).

    Args:
        image: A PIL image (anything exposing ``size`` and ``resize`` works).
        max_size: Maximum allowed width and height, in pixels.
        resample: Resampling filter passed to ``image.resize``. ``None`` selects
            ``Image.LANCZOS``, which is the previous fixed behaviour.

    Returns:
        Either *image* itself, or a resized copy whose longer side equals
        ``max_size``.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    if resample is None:
        resample = Image.LANCZOS
    # Pin the longer side to max_size and scale the shorter one to match.
    # The shorter side is rounded (not truncated) and clamped to at least 1 px,
    # so extreme aspect ratios such as 10000x1 never request a zero-sized
    # target, which PIL's resize rejects.
    if width > height:
        new_width = max_size
        new_height = max(1, round(height * max_size / width))
    else:
        new_height = max_size
        new_width = max(1, round(width * max_size / height))
    return image.resize((new_width, new_height), resample)
def split_image(image, chunk_size=512):
    """Cut *image* into a row-major grid of tiles at most ``chunk_size`` square.

    The tiles on the right and bottom edges are smaller whenever the image
    dimensions are not multiples of ``chunk_size``.

    Returns:
        A list of ``(tile, left, top)`` tuples. ``left`` and ``top`` are the
        tile's offset within *image*, ordered row by row, left to right.
    """
    full_width, full_height = image.size
    return [
        (
            image.crop((
                left,
                top,
                min(left + chunk_size, full_width),
                min(top + chunk_size, full_height),
            )),
            left,
            top,
        )
        for top in range(0, full_height, chunk_size)
        for left in range(0, full_width, chunk_size)
    ]
def stitch_image(chunks, original_size):
    """Paste tiles onto a fresh RGB canvas of size ``original_size``.

    Args:
        chunks: Iterable of ``(tile, left, top)`` tuples. Each tile is pasted
            with its top-left corner at ``(left, top)``.
        original_size: ``(width, height)`` of the canvas to create.

    Returns:
        The assembled RGB image. Areas that no tile covers keep the canvas's
        default fill colour.
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        canvas.paste(tile, box=(left, top))
    return canvas
def upscale_chunk(chunk, model, processor, device):
    """Run one image through the super-resolution model and return a PIL image.

    The output is not cropped here. If the processor padded its input, that
    padding is still present in the returned image.
    NOTE(review): whether and how much padding is added depends on the
    processor — confirm against the installed transformers version.

    Args:
        chunk: PIL image to upscale.
        model: Model whose output exposes a ``reconstruction`` tensor.
        processor: Image processor producing the model's keyword inputs.
        device: torch device that the inputs are moved to.
    """
    encoded = processor(chunk, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        reconstruction = model(**model_inputs).reconstruction
    # Drop singleton dims (the batch axis), move the data to the CPU as float32,
    # and clamp to [0, 1] before converting to 8-bit.
    pixels = reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
    # Move the leading (presumably channel) axis last, as PIL expects; this
    # assumes a channel-first tensor — TODO confirm.
    pixels = np.moveaxis(pixels, source=0, destination=-1)
    as_bytes = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(as_bytes)
def remove_boundary(image, boundary=32):
    """Trim ``boundary`` pixels from the right and bottom edges of *image*.

    The top-left corner stays fixed. The trimmed result comes from
    ``image.crop``.
    """
    right_edge = image.width - boundary
    bottom_edge = image.height - boundary
    return image.crop((0, 0, right_edge, bottom_edge))
@spaces.GPU
def main(image, original_filename, model_choice, save_as_jpg=True, use_tiling=True):
    """Upscale *image* with the chosen Swin2SR checkpoint and save the result.

    First, the input is limited to 2048 px per side (``resize_image``). It is
    then upscaled either tile by tile (512 px tiles, stitched back together) or
    in a single pass. The result is written to the current working directory.

    Args:
        image: Input PIL image.
        original_filename: Name used to derive the output filename. Only its
            base name, without directory or extension, is used; falsy values
            fall back to ``"image"``.
        model_choice: UI label selecting one of the checkpoints listed below.
        save_as_jpg: Save as JPEG (quality 95) when True, otherwise as PNG.
        use_tiling: Upscale 512 px tiles separately instead of the whole image.

    Returns:
        Relative path of the saved file: ``<basename>_upscaled.jpg`` or
        ``<basename>_upscaled.png``.

    Raises:
        KeyError: If ``model_choice`` is not one of the known labels.
    """
    image = resize_image(image)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
    }
    checkpoint = model_paths[model_choice]
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint).to(device)
    # Both checkpoints are x4 models (per their names). Prefer the factor
    # recorded in the model config and fall back to 4 if it is absent.
    scale = getattr(model.config, "upscale", 4)

    if use_tiling:
        upscaled_chunks = [
            (
                _upscale_to_input_size(chunk, model, processor, device, scale),
                x * scale,
                y * scale,
            )
            for chunk, x, y in split_image(image)
        ]
        upscaled_image = stitch_image(
            upscaled_chunks, (image.width * scale, image.height * scale)
        )
    else:
        upscaled_image = _upscale_to_input_size(image, model, processor, device, scale)

    # Keep only the base name, so that a path-like name cannot redirect the write.
    base_name = os.path.basename(original_filename) if original_filename else ""
    original_basename = os.path.splitext(base_name)[0] or "image"
    output_filename = f"{original_basename}_upscaled"
    if save_as_jpg:
        output_filename += ".jpg"
        upscaled_image.save(output_filename, quality=95)
    else:
        output_filename += ".png"
        upscaled_image.save(output_filename)
    return output_filename


def _upscale_to_input_size(image, model, processor, device, scale):
    """Upscale *image* and crop the result to exactly ``scale`` times its size.

    The model output can be larger than ``scale`` x the input because the
    processor pads the input on the bottom/right before inference. The amount
    of padding depends on the input size, so a fixed 32 px right/bottom trim
    (``remove_boundary``) removes real pixels from inputs whose sides are not
    multiples of 8. When tiles are stitched, that leaves uncovered strips on the
    canvas.
    NOTE(review): assumes padding is added only on the bottom/right, which the
    right/bottom-only trim implies — confirm against the installed transformers
    Swin2SRImageProcessor.
    """
    upscaled = upscale_chunk(image, model, processor, device)
    return upscaled.crop((0, 0, image.width * scale, image.height * scale))
def gradio_interface(image, model_choice, save_as_jpg, use_tiling):
    """Gradio callback: upscale *image* and return ``(file_path, error_message)``.

    Exactly one element of the returned pair is ``None``. On success it is the
    error message; on failure it is the file path, and the other element holds
    the text shown in the "Error Message" box.
    """
    # Submitting without an upload passes None. Report that plainly instead of
    # surfacing an AttributeError from inside main().
    if image is None:
        return None, "Please upload an image."
    try:
        # Best-effort source name. PIL sets ``filename`` on images created via
        # Image.open; ``name`` is kept for file-like inputs. Whether the UI
        # preserves either attribute is not guaranteed, so fall back to "image".
        source_name = getattr(image, "filename", "") or getattr(image, "name", "") or "image"
        original_filename = os.path.basename(source_name) or "image"
        result = main(image, original_filename, model_choice, save_as_jpg, use_tiling)
    except Exception as e:  # UI boundary: show the message instead of failing the request
        return None, str(e)
    return result, None
# Gradio UI. Inputs: an image upload, a model picker and two option checkboxes.
# Outputs: the saved file and any error text. Components are created in the
# same order as they are listed.
interface_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Dropdown(
        choices=["PSNR Match (Recommended)", "Pixel Perfect"],
        value="PSNR Match (Recommended)",
        label="Select Model",
    ),
    gr.Checkbox(label="Save as JPEG", value=True),
    gr.Checkbox(label="Use Tiling", value=True),
]
interface_outputs = [
    gr.File(label="Download Upscaled Image"),
    gr.Textbox(label="Error Message", visible=True),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="Image Upscaler",
    description=(
        "Upload an image, select a model, and upscale it. "
        "Images larger than 2048x2048 will be resized while maintaining aspect ratio. "
        "Use tiling for efficient processing of large images."
    ),
)
interface.launch()