File size: 4,245 Bytes
29356cb
 
 
5e534b3
29356cb
f1ee166
92c37e9
29356cb
0782bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
 
13a4c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
 
4a66938
0782bc0
 
 
4a66938
 
 
 
 
 
 
 
 
29356cb
13a4c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
 
4a66938
13a4c81
 
 
 
29356cb
 
 
 
 
e2d6adc
4a66938
 
 
 
 
f1ee166
29356cb
13a4c81
 
 
 
29356cb
13a4c81
29356cb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Import necessary libraries
import os
import tempfile

from PIL import Image
import numpy as np
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces

# Function to resize image to max 2048x2048 while maintaining aspect ratio
def resize_image(image, max_size=2048):
    width, height = image.size
    if width > max_size or height > max_size:
        aspect_ratio = width / height
        if width > height:
            new_width = max_size
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = max_size
            new_width = int(new_height * aspect_ratio)
        image = image.resize((new_width, new_height), Image.LANCZOS)
    return image

# Function to upscale an image using Swin2SR
def upscale_image(image, model, processor, device):
    try:
        # Convert the image to RGB format
        image = image.convert("RGB")
        # Process the image for the model
        inputs = processor(image, return_tensors="pt")
        # Move inputs to the same device as model
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Perform inference (upscale)
        with torch.no_grad():
            outputs = model(**inputs)
        # Move output back to CPU for further processing
        output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)
        output_image = (output * 255.0).round().astype(np.uint8)  # Convert from float32 to uint8
        # Remove 32 pixels from the bottom and right of the image
        output_image = output_image[:-32, :-32]
        return Image.fromarray(output_image), None
    except RuntimeError as e:
        return None, str(e)

@spaces.GPU
def main(image, model_choice, save_as_jpg=True):
    """Resize, upscale with the selected Swin2SR model, and save the result.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.
        model_choice: a key of ``model_paths`` below (the dropdown label).
        save_as_jpg: save as JPEG (quality 98) when True, otherwise as PNG.

    Returns:
        The path of the saved file on success, or a message starting with
        "Error:" on failure (``gradio_interface`` checks for that prefix).
    """
    if image is None:
        return "Error: No image was provided."

    # Cap the input size before upscaling
    image = resize_image(image)

    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }

    # Load the selected 4x Swin2SR model and its processor (done on every call)
    processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
    model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice])

    # Try GPU first (when available), then fall back to CPU. Without CUDA both
    # entries are CPU, and the loop returns after the first CPU attempt.
    for device in [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")]:
        model.to(device)
        upscaled_image, error = upscale_image(image, model, processor, device)

        if upscaled_image is not None:
            filename = "upscaled_image.jpg" if save_as_jpg else "upscaled_image.png"
            # A fresh temporary directory per request keeps concurrent users
            # from overwriting (and downloading) each other's result while
            # preserving the download file name. The directory is not removed here.
            output_path = os.path.join(tempfile.mkdtemp(), filename)
            if save_as_jpg:
                # quality=98 is near-maximum JPEG quality (light compression)
                upscaled_image.save(output_path, quality=98)
            else:
                upscaled_image.save(output_path)
            return output_path

        if device.type == "cpu":
            return f"Error: Unable to process the image. {error}"

    # Defensive only: the last device tried is always CPU, which returns above.
    return "Error: Unable to process the image on both GPU and CPU."

def gradio_interface(image, model_choice, save_as_jpg):
    """Adapt ``main`` to the two Gradio outputs: (file, error message).

    ``main`` reports failure with a string beginning with "Error:". In that
    case the file output is cleared and the message is shown instead;
    otherwise the file path is returned and the message box is left empty.
    """
    outcome = main(image, model_choice, save_as_jpg)
    if not outcome.startswith("Error:"):
        return outcome, None
    return gr.update(value=None), outcome

# Gradio UI: image + model choice + format toggle in; downloadable file + error text out
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            label="Select Model",
            value="PSNR Match (Recommended)"
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, upscale it, and download the new image. Images larger than 2048x2048 will be resized while maintaining aspect ratio. If GPU processing fails, it will attempt to process on CPU.",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    interface.launch()