File size: 4,046 Bytes
25ad706
29356cb
 
 
f1ee166
92c37e9
29356cb
337146c
 
 
 
 
 
 
 
 
 
 
 
 
25ad706
0782bc0
25ad706
 
 
 
 
 
0782bc0
25ad706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
337146c
4a66938
337146c
 
 
25ad706
0782bc0
4a66938
 
 
 
 
 
25ad706
29356cb
25ad706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
4a66938
25ad706
 
 
 
 
29356cb
 
 
 
e2d6adc
4a66938
 
 
 
 
f1ee166
29356cb
13a4c81
 
 
 
29356cb
337146c
29356cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces

def resize_image(image, max_size=2048):
    """Shrink *image* so neither side exceeds *max_size*, keeping aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), resample)`` method (a PIL image here).
        max_size: Largest allowed length, in pixels, of either side.

    Returns:
        The very same object when it already fits; otherwise the result of
        ``image.resize`` with Lanczos resampling, whose longer side equals
        ``max_size``.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    aspect_ratio = width / height
    if width > height:
        new_width = max_size
        # int() truncates toward zero, so a very wide image could end up with
        # a 0-pixel height; clamp to 1 to keep the target size valid.
        new_height = max(1, int(new_width / aspect_ratio))
    else:
        new_height = max_size
        # Same clamp for very tall images.
        new_width = max(1, int(new_height * aspect_ratio))
    return image.resize((new_width, new_height), Image.LANCZOS)

def split_image(image, chunk_size=512):
    """Cut *image* into tiles of at most ``chunk_size`` x ``chunk_size`` pixels.

    Tiles are produced row by row (top to bottom, left to right within a row).
    Tiles on the right and bottom borders are smaller when the image size is
    not a multiple of ``chunk_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``crop((left, upper, right, lower))``.
        chunk_size: Maximum tile edge length in pixels.

    Returns:
        A list of ``(tile, x, y)`` tuples, where ``(x, y)`` is the tile's
        top-left corner in the source image.
    """
    total_w, total_h = image.size
    return [
        (
            image.crop(
                (
                    left,
                    top,
                    min(left + chunk_size, total_w),
                    min(top + chunk_size, total_h),
                )
            ),
            left,
            top,
        )
        for top in range(0, total_h, chunk_size)
        for left in range(0, total_w, chunk_size)
    ]

def stitch_image(chunks, original_size):
    """Assemble positioned tiles onto a fresh RGB canvas.

    Args:
        chunks: Iterable of ``(tile, x, y)`` tuples; each tile is pasted with
            its top-left corner at ``(x, y)``.
        original_size: ``(width, height)`` of the canvas to create.

    Returns:
        The new RGB image with every tile pasted, in iteration order.
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        position = (left, top)
        canvas.paste(tile, position)
    return canvas

def upscale_chunk(chunk, model, processor, device):
    """Run one image tile through the super-resolution model.

    Args:
        chunk: The tile to upscale, passed as-is to ``processor``.
        model: Callable model whose output exposes a ``reconstruction`` tensor.
        processor: Callable that turns ``chunk`` into a mapping of tensors
            when invoked with ``return_tensors="pt"``.
        device: Torch device the input tensors are moved to.

    Returns:
        A PIL image built from the model's reconstruction, as 8-bit values.
    """
    batch = processor(chunk, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        result = model(**batch)

    # Drop size-1 dimensions, move to CPU and clamp into [0, 1].
    # NOTE(review): assumes the squeezed tensor is channel-first (C, H, W) — confirm.
    tensor = result.reconstruction.data.squeeze()
    tensor = tensor.cpu().float().clamp_(0, 1)
    pixels = tensor.numpy()

    # Move axis 0 to the end so the array is laid out as Image.fromarray expects.
    pixels = np.moveaxis(pixels, source=0, destination=-1)

    # Scale [0, 1] floats to 0..255 integers.
    as_bytes = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(as_bytes)

@spaces.GPU(duration=120)
def main(image, model_choice, save_as_jpg=True):
    """Upscale *image* 4x with the chosen Swin2SR checkpoint and save it to disk.

    The image is first limited in size by ``resize_image``, split into tiles,
    each tile is upscaled independently, and the results are stitched back
    together into an image exactly four times the (resized) input size.

    Args:
        image: PIL image to upscale.
        model_choice: Key selecting the checkpoint; one of
            ``"Pixel Perfect"`` or ``"PSNR Match (Recommended)"``.
        save_as_jpg: When true, save as JPEG (quality 95); otherwise as PNG.

    Returns:
        Path of the written file: ``"upscaled_image.jpg"`` or
        ``"upscaled_image.png"``.

    Raises:
        ValueError: If ``image`` is ``None``.
        KeyError: If ``model_choice`` is not a known model name.
    """
    if image is None:
        raise ValueError("No image provided. Please upload an image.")

    # Both checkpoints below are x4 super-resolution models.
    scale = 4

    # Resize the input image
    image = resize_image(image)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    checkpoint = model_paths[model_choice]

    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint).to(device)

    # Upscale tile by tile
    upscaled_chunks = []
    for chunk, x, y in split_image(image):
        upscaled_chunk = upscale_chunk(chunk, model, processor, device)
        # Padding added during preprocessing shows up as extra rows/columns on
        # the bottom and right of the output. How much depends on the tile
        # size, so crop to exactly scale x the source tile instead of removing
        # a fixed number of pixels (which mis-crops the smaller edge tiles).
        target_width = min(upscaled_chunk.width, chunk.width * scale)
        target_height = min(upscaled_chunk.height, chunk.height * scale)
        upscaled_chunk = upscaled_chunk.crop((0, 0, target_width, target_height))
        # Tile origins scale with the image.
        upscaled_chunks.append((upscaled_chunk, x * scale, y * scale))

    # Stitch the chunks back together; the canvas is exactly scale x the input
    # so no real pixels are cut off the right/bottom edges.
    final_size = (image.width * scale, image.height * scale)
    upscaled_image = stitch_image(upscaled_chunks, final_size)

    if save_as_jpg:
        upscaled_image.save("upscaled_image.jpg", quality=95)
        return "upscaled_image.jpg"
    else:
        upscaled_image.save("upscaled_image.png")
        return "upscaled_image.png"

def gradio_interface(image, model_choice, save_as_jpg):
    """Adapter between the Gradio UI and ``main``.

    Any exception raised while upscaling is converted into text so it can be
    displayed to the user instead of propagating.

    Returns:
        ``(output_path, None)`` on success, or ``(None, error_message)`` when
        ``main`` raised an exception.
    """
    try:
        output_path = main(image, model_choice, save_as_jpg)
    except Exception as err:
        return None, str(err)
    else:
        return output_path, None

# Gradio front end: three inputs (image, model, output format) are passed to
# gradio_interface, whose two return values fill the file and error outputs.
interface = gr.Interface(
    title="Image Upscaler",
    description=(
        "Upload an image, select a model, and upscale it. "
        "Images larger than 2048x2048 will be resized while maintaining aspect ratio. "
        "The image will be processed in 512x512 pixel chunks for efficient handling."
    ),
    fn=gradio_interface,
    inputs=[
        # Delivered to the callback as a PIL image.
        gr.Image(label="Upload Image", type="pil"),
        # Choices must match the keys of model_paths in main().
        gr.Dropdown(
            label="Select Model",
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            value="PSNR Match (Recommended)",
        ),
        gr.Checkbox(label="Save as JPEG", value=True),
    ],
    outputs=[
        # First return value: path of the saved image (or None on failure).
        gr.File(label="Download Upscaled Image"),
        # Second return value: error text (or None on success).
        gr.Textbox(label="Error Message", visible=True),
    ],
)

# Launch the app.
interface.launch()