File size: 4,343 Bytes
25ad706
29356cb
 
 
f1ee166
92c37e9
7129683
29356cb
337146c
 
 
 
 
 
 
 
 
 
 
 
 
25ad706
0782bc0
25ad706
 
 
 
 
 
0782bc0
25ad706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
ce069dd
 
 
65b549e
80950c2
337146c
 
7129683
0782bc0
4a66938
 
 
 
 
 
25ad706
29356cb
9fbd930
 
 
 
 
ce069dd
 
9fbd930
ce069dd
9fbd930
 
ce069dd
25ad706
80950c2
 
7129683
25ad706
7129683
 
25ad706
7129683
 
 
 
29356cb
7129683
25ad706
80950c2
 
25ad706
 
 
29356cb
 
 
 
e2d6adc
4a66938
 
 
 
 
f1ee166
9fbd930
29356cb
13a4c81
 
 
 
29356cb
7129683
29356cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
import os

def resize_image(image, max_size=2048):
    """Shrink ``image`` so that neither side exceeds ``max_size``.

    The aspect ratio is preserved: the longer side becomes ``max_size`` and
    the shorter side is scaled by the same factor (truncated to an integer).
    Images that already fit are returned unchanged (same object).

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (a PIL image in this file).
        max_size: Largest allowed width or height, in pixels.

    Returns:
        The original image if it already fits, otherwise a resized copy.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    aspect_ratio = width / height
    if width > height:
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    # Truncation can drive the short side to 0 for very elongated images
    # (e.g. 10000x1); a zero-sized target is not a valid resize, so keep at
    # least one pixel on each side.
    new_size = (max(1, new_width), max(1, new_height))
    return image.resize(new_size, Image.LANCZOS)

def split_image(image, chunk_size=512):
    """Cut ``image`` into a row-major grid of tiles.

    Tiles are at most ``chunk_size`` pixels on each side; tiles on the right
    and bottom edges are smaller when the image size is not a multiple of
    ``chunk_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method (a PIL image in this file).
        chunk_size: Maximum tile width and height, in pixels.

    Returns:
        A list of ``(tile, x, y)`` tuples, where ``(x, y)`` is the tile's
        top-left corner in the source image. Rows come first (top to
        bottom), and within a row tiles run left to right.
    """
    width, height = image.size
    return [
        (
            image.crop((
                left,
                top,
                min(left + chunk_size, width),
                min(top + chunk_size, height),
            )),
            left,
            top,
        )
        for top in range(0, height, chunk_size)
        for left in range(0, width, chunk_size)
    ]

def stitch_image(chunks, original_size):
    """Assemble tiles into a single RGB image.

    Args:
        chunks: Iterable of ``(tile, x, y)`` tuples; each tile is pasted with
            its top-left corner at ``(x, y)`` on the output canvas.
        original_size: ``(width, height)`` of the output canvas.

    Returns:
        A new RGB image of ``original_size`` with every tile pasted in the
        order given (later tiles overwrite earlier ones where they overlap).
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        canvas.paste(tile, (left, top))
    return canvas

def upscale_chunk(chunk, model, processor, device):
    """Run one image through the super-resolution model.

    Args:
        chunk: Image to upscale, in whatever form ``processor`` accepts.
        model: Callable model whose output exposes a ``reconstruction``
            tensor.
        processor: Callable that turns ``chunk`` into a mapping of PyTorch
            tensors (called with ``return_tensors="pt"``).
        device: Torch device the input tensors are moved to before inference.

    Returns:
        A PIL image built from the model output, as 8-bit data.
    """
    batch = processor(chunk, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(**batch)

    # Drop singleton dimensions, bring the data to the CPU as float32 and
    # clamp it to the [0, 1] range before converting to NumPy.
    pixels = prediction.reconstruction.data.squeeze()
    pixels = pixels.cpu().float().clamp_(0, 1).numpy()

    # Move the first axis to the end (presumably channels-first ->
    # channels-last, which is what Image.fromarray expects for colour data).
    pixels = np.moveaxis(pixels, 0, -1)

    # Scale [0, 1] floats to rounded 8-bit values.
    pixels_u8 = np.round(pixels * 255.0).astype(np.uint8)
    return Image.fromarray(pixels_u8)

def remove_boundary(image, boundary=32):
    """Trim ``boundary`` pixels off the right and bottom edges of ``image``.

    Args:
        image: Object exposing ``width``, ``height`` and ``crop(box)``
            (a PIL image in this file).
        boundary: Number of pixels to remove from the right and bottom.

    Returns:
        The result of cropping to the box
        ``(0, 0, width - boundary, height - boundary)``.
    """
    right = image.width - boundary
    bottom = image.height - boundary
    return image.crop((0, 0, right, bottom))

@spaces.GPU
def main(image, original_filename, model_choice, save_as_jpg=True, use_tiling=True):
    """Upscale ``image`` with a Swin2SR model and save the result to disk.

    Args:
        image: PIL image to upscale. It is first shrunk by ``resize_image``
            if it exceeds that function's size limit.
        original_filename: Name used to derive the output file name; its
            extension is stripped. ``"image"`` is used when this is falsy.
        model_choice: Which checkpoint to use: ``"Pixel Perfect"`` or
            ``"PSNR Match (Recommended)"``.
        save_as_jpg: When true the result is written as a JPEG
            (quality 95), otherwise as a PNG.
        use_tiling: When true the image is processed tile by tile (see
            ``split_image``) and stitched back together; otherwise it is
            upscaled in a single pass.

    Returns:
        The path of the written file, ``"<basename>_upscaled.jpg"`` or
        ``"<basename>_upscaled.png"``.

    Raises:
        KeyError: If ``model_choice`` is not one of the known model names.
    """
    image = resize_image(image)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    checkpoint = model_paths[model_choice]

    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint).to(device)

    # Both checkpoints above are x4 models; read the factor from the model
    # config when available so the geometry below follows the model.
    scale = getattr(model.config, "upscale", 4)

    def upscale_exact(tile):
        # The image processor may pad its input before inference, so the raw
        # output can be larger than scale * tile size, and the amount of
        # padding depends on the tile's dimensions. Trimming a fixed number
        # of pixels is only right for one particular padding; cropping to the
        # exact target size is right for every tile size.
        upscaled = upscale_chunk(tile, model, processor, device)
        return upscaled.crop((0, 0, tile.width * scale, tile.height * scale))

    if use_tiling:
        upscaled_chunks = [
            (upscale_exact(chunk), x * scale, y * scale)
            for chunk, x, y in split_image(image)
        ]
        upscaled_image = stitch_image(
            upscaled_chunks, (image.width * scale, image.height * scale)
        )
    else:
        upscaled_image = upscale_exact(image)

    original_basename = os.path.splitext(original_filename)[0] if original_filename else "image"
    output_filename = f"{original_basename}_upscaled"

    if save_as_jpg:
        output_filename += ".jpg"
        upscaled_image.save(output_filename, quality=95)
    else:
        output_filename += ".png"
        upscaled_image.save(output_filename)

    return output_filename

def gradio_interface(image, model_choice, save_as_jpg, use_tiling):
    """Gradio callback: upscale the uploaded image and report errors as text.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        model_choice: Model name selected in the dropdown.
        save_as_jpg: Whether to save the result as JPEG instead of PNG.
        use_tiling: Whether to process the image in tiles.

    Returns:
        ``(output_path, None)`` on success, or ``(None, error_message)`` when
        no image was supplied or processing failed. Exceptions are never
        propagated; they are turned into the error message.
    """
    if image is None:
        # Without this guard the user would see a raw AttributeError text.
        return None, "Please upload an image before upscaling."

    try:
        original_filename = getattr(image, 'name', 'image')
        result = main(image, original_filename, model_choice, save_as_jpg, use_tiling)
    except Exception as e:
        # UI boundary: surface the failure in the error textbox. Some
        # exceptions have an empty message, so fall back to their repr to
        # avoid showing a blank error.
        return None, str(e) or repr(e)
    return result, None

# Input widgets, in the order gradio_interface receives them:
# image, model choice, JPEG flag, tiling flag.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Dropdown(
        choices=["PSNR Match (Recommended)", "Pixel Perfect"],
        label="Select Model",
        value="PSNR Match (Recommended)"
    ),
    gr.Checkbox(value=True, label="Save as JPEG"),
    gr.Checkbox(value=True, label="Use Tiling"),
]

# Output widgets, matching gradio_interface's (file path, error message) pair.
_output_components = [
    gr.File(label="Download Upscaled Image"),
    gr.Textbox(label="Error Message", visible=True)
]

interface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=_output_components,
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. Images larger than 2048x2048 will be resized while maintaining aspect ratio. Use tiling for efficient processing of large images.",
)

# Start the web app as soon as the module is run.
interface.launch()