import torch
import torch.nn.functional as F
import os
from skimage import img_as_ubyte
import cv2
import gradio as gr
from PIL import Image
from runpy import run_path
import numpy as np
examples = [['./sample1.png'], ['./sample2.png'], ['./Sample3.png'],
            ['./Sample4.png'], ['./Sample5.png'], ['./Sample6.png']]
title = "Restormer"
description = """
Gradio demo for reconstruction of noisy scanned, photocopied documents\n
using <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
<a href='https://toon-beerten.medium.com/denoising-and-reconstructing-dirty-documents-for-optimal-digitalization-ed3a186aa3d6'>[See my article for more details]</a>\n
<b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup).
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"
def inference(img):
    # Scratch directory kept from the original demo pipeline.
    os.makedirs('temp', exist_ok=True)

    # Downsample so the longer edge is at most max_res pixels,
    # since this demo runs on CPU.
    max_res = 400
    width, height = img.size
    if max(width, height) > max_res:
        scale = max_res / max(width, height)
        width = int(scale * width)
        height = int(scale * height)
        img = img.resize((width, height))
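    # Example: a 1200x800 input gives scale = 400/1200, so it is resized to 400x266.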
    # Default Restormer hyperparameters from the official repo.
    parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8],
                  'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                  'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}

    # Build the model from restormer_arch.py and load the fine-tuned weights.
    # Note: the model is rebuilt on every call, which is acceptable for a demo.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    load_arch = run_path('restormer_arch.py')
    model = load_arch['Restormer'](**parameters)
    # map_location makes a GPU-saved checkpoint loadable on CPU-only machines.
    checkpoint = torch.load('net_g_92000.pth', map_location=device)
    model.load_state_dict(checkpoint['params'])
    model = model.to(device)
    model.eval()
    img_multiple_of = 8
    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # PIL gives RGB; this channel swap is undone by the matching
        # swap before returning, so end-to-end colors are preserved.
        open_cv_image = np.array(img)
        img = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
        input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        # Pad height/width up to the next multiple of 8 with reflect padding,
        # since the encoder-decoder downsamples the input.
        h, w = input_.shape[2], input_.shape[3]
        H = ((h + img_multiple_of) // img_multiple_of) * img_multiple_of
        W = ((w + img_multiple_of) // img_multiple_of) * img_multiple_of
        padh = H - h if h % img_multiple_of != 0 else 0
        padw = W - w if w % img_multiple_of != 0 else 0
        input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
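        # Worked example: h = 250 gives H = ((250+8)//8)*8 = 256, so padh = 6;
        # if h is already a multiple of 8 (say 256), H = 264 but padh is forced to 0.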
        restored = torch.clamp(model(input_), 0, 1)
        # Crop the padding back off and convert to an 8-bit HWC array.
        restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().detach().numpy()[0])
    # Swap channels back (see above) and return a PIL image.
    return Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
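# Example of calling the function directly, bypassing the Gradio UI
# (illustrative; 'restored.png' is an arbitrary output name):
#   out = inference(Image.open('./sample1.png'))
#   out.save('restored.png')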
gr.Interface(
    inference,
    [gr.Image(type="pil", label="Input")],
    gr.Image(type="pil", label="Cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(debug=False)
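# To run locally (assuming this file is saved as app.py, with restormer_arch.py
# and the checkpoint net_g_92000.pth alongside it):
#   python app.py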