danifei commited on
Commit
1519a8d
·
verified ·
1 Parent(s): 4fcebd2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ import numpy as np
6
+
7
+ from archs.model import FourNet
8
+
9
+
10
+ opt = parse(path_opt)
11
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
12
+ #define some auxiliary functions
13
+ pil_to_tensor = transforms.ToTensor()
14
+
15
+ # define some parameters based on the run we want to make
16
+
17
+ model = FourNet(nf = 16)
18
+
19
+ checkpoints = torch.load('./models/NAFourNet16_LOLv2Real.pt', map_location=device)
20
+
21
+ model.load_state_dict(checkpoints['model_state_dict'])
22
+
23
+ model = model.to(device)
24
+
25
+ def load_img (filename):
26
+ img = Image.open(filename).convert("RGB")
27
+ img_tensor = pil_to_tensor(img)
28
+ return img_tensor
29
+
30
+ def process_img(image):
31
+ img = np.array(image)
32
+ img = img / 255.
33
+ img = img.astype(np.float32)
34
+ y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
35
+
36
+ with torch.no_grad():
37
+ x_hat = model(y)
38
+
39
+ restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
40
+ restored_img = np.clip(restored_img, 0. , 1.)
41
+
42
+ restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
43
+ return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
44
+
45
+ title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
46
+ description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
47
+
48
+ [Juan Carlos Benito](https://github.com/juaben)
49
+
50
+ Fundación Cidaut
51
+
52
+
53
+ > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
54
+ **This demo expects an image with some degradations.**
55
+ Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
56
+
57
+ <br>
58
+ '''
59
+
60
+ examples = [['examples/inputs/0010.png'],
61
+ ['examples/inputs/0060.png'],
62
+ ['examples/inputs/0075.png'],
63
+ ["examples/inputs/0087.png"],
64
+ ["examples/inputs/0088.png"]]
65
+
66
+ css = """
67
+ .image-frame img, .image-container img {
68
+ width: auto;
69
+ height: auto;
70
+ max-width: none;
71
+ }
72
+ """
73
+
74
+ demo = gr.Interface(
75
+ fn = process_img,
76
+ inputs = [
77
+ gr.Image(type = 'pil', label = 'input')
78
+ ],
79
+ outputs = [gr.Image(type='pil', label = 'output')],
80
+ title = title,
81
+ description = description,
82
+ examples = examples,
83
+ css = css
84
+ )
85
+
86
+ if __name__ == '__main__':
87
+ demo.launch()