juaben committed
Commit f3c4d54 · verified · 1 Parent(s): d5c5d2d

Update app.py

Files changed (1)
  1. app.py +33 -27
app.py CHANGED
@@ -1,4 +1,4 @@
- import gradio as gr
+ import gradio as gr
  import torch
  import torchvision.transforms as transforms
  import numpy as np
@@ -10,46 +10,51 @@ from model.flol import create_model
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  #define some auxiliary functions
  pil_to_tensor = transforms.ToTensor()
- 
- # define some parameters based on the run we want to make
+ 
+ # Define a list of available weight files
+ weights_files = {
+     "flolv2_UHDLL": './weights/flolv2_UHDLL.pt',
+     "flolv2_other": './weights/flolv2_other.pt', # Add other available weights here
+ }
+ 
+ # Initial model setup (without weights)
  model = create_model()
- checkpoints = torch.load('./weights/flolv2_UHDLL.pt', map_location=device)
- model.load_state_dict(checkpoints['params'])
- model = model.to(device)
 
- def load_img (filename):
+ def load_img(filename):
      img = Image.open(filename).convert("RGB")
      img_tensor = pil_to_tensor(img)
      return img_tensor
 
- def process_img(image):
+ def process_img(image, weight_file):
+     # Load the selected weight file
+     if weight_file in weights_files:
+         model_path = weights_files[weight_file]
+         checkpoints = torch.load(model_path, map_location=device)
+         model.load_state_dict(checkpoints['params'])
+         model.to(device)
+ 
      img = np.array(image)
-     img = img / 255.
+     img = img / 255. # Normalize to [0, 1]
      img = img.astype(np.float32)
-     y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
+     y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
 
      with torch.no_grad():
          x_hat = model(y)
 
-     restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
-     restored_img = np.clip(restored_img, 0. , 1.)
-
-     restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
-     return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
+     restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
+     restored_img = np.clip(restored_img, 0., 1.)
+ 
+     restored_img = (restored_img * 255.0).round().astype(np.uint8) # Convert to uint8
+     return Image.fromarray(restored_img)
 
  title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
  description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
- 
  [Juan Carlos Benito](https://github.com/juaben)
- 
  Fundación Cidaut
- 
- 
  > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
  **This demo expects an image with some degradations.**
  Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
- 
  <br>
  '''
 
@@ -69,15 +74,16 @@ css = """
  """
 
  demo = gr.Interface(
-     fn = process_img,
-     inputs = [
-         gr.Image(type = 'pil', label = 'input')
+     fn=process_img,
+     inputs=[
+         gr.Image(type='pil', label='input'),
+         gr.Dropdown(choices=list(weights_files.keys()), label='Select Weight File', default="flolv2_UHDLL")
      ],
-     outputs = [gr.Image(type='pil', label = 'output')],
-     title = title,
-     description = description,
-     examples = examples,
-     css = css
+     outputs=[gr.Image(type='pil', label='output')],
+     title=title,
+     description=description,
+     examples=examples,
+     css=css
  )
 
  if __name__ == '__main__':
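
For a quick local check of the updated entry point, the new process_img(image, weight_file) signature can be exercised outside the Gradio UI. A minimal sketch, not part of the commit, assuming app.py has been imported (so process_img and weights_files are in scope), the "flolv2_UHDLL" checkpoint exists under ./weights/, and ./images/low_light.png is a hypothetical test image path:

    # Minimal sketch: call the updated process_img directly with a weight key.
    # Assumes the flolv2_UHDLL checkpoint is available and the image path is a placeholder.
    from PIL import Image

    low_light = Image.open('./images/low_light.png').convert('RGB')   # hypothetical test image
    enhanced = process_img(low_light, 'flolv2_UHDLL')                 # weight key from weights_files
    enhanced.save('./images/low_light_enhanced.png')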