juaben commited on
Commit
7b39006
·
verified ·
1 Parent(s): 59b43a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -6,16 +6,20 @@ import numpy as np
6
  from PIL import Image
7
  from model.flol import create_model
8
 
9
-
10
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
  #define some auxiliary functions
12
  pil_to_tensor = transforms.ToTensor()
13
 
14
- # Define a list of available weight files
15
- weights_files = {
16
- "flolv2_UHDLL": './weights/flolv2_UHDLL.pt',
17
- "flolv2_LOLv2-Real": './weights/flolv2_all_111439.pt', # Add other available weights here
18
- }
 
 
 
 
 
19
 
20
  # Initial model setup (without weights)
21
  model = create_model()
@@ -25,10 +29,10 @@ def load_img(filename):
25
  img_tensor = pil_to_tensor(img)
26
  return img_tensor
27
 
28
- def process_img(image, weight_file):
29
- # Load the selected weight file
30
- if weight_file in weights_files:
31
- model_path = weights_files[weight_file]
32
  checkpoints = torch.load(model_path, map_location=device)
33
  model.load_state_dict(checkpoints['params'])
34
  model.to(device)
@@ -59,9 +63,9 @@ Due to the GPU memory limitations, the app might crash if you feed a high-resolu
59
  '''
60
 
61
  examples = [['images/425_UHD_LL.JPG'],
62
- ['images/674_UHD_LL.JPG'],
63
- ['images/685_UHD_LL.JPG'],
64
- ['images/945_UHD_LL.JPG'],
65
  ['images/1778_UHD_LL.JPG'],
66
  ['images/1791_UHD_LL.JPG']]
67
 
@@ -76,8 +80,8 @@ css = """
76
  demo = gr.Interface(
77
  fn=process_img,
78
  inputs=[
79
- gr.Image(type='pil', label='input'),
80
- gr.Dropdown(choices=list(weights_files.keys()), label='Select Weight File', default="flolv2_UHDLL")
81
  ],
82
  outputs=[gr.Image(type='pil', label='output')],
83
  title=title,
@@ -86,5 +90,16 @@ demo = gr.Interface(
86
  css=css
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  if __name__ == '__main__':
90
  demo.launch()
 
6
  from PIL import Image
7
  from model.flol import create_model
8
 
 
9
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
10
  #define some auxiliary functions
11
  pil_to_tensor = transforms.ToTensor()
12
 
13
+ # Define a dictionary to map image filenames to weight files
14
+ image_to_weights = {
15
+ "425_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
16
+ "1778_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
17
+ "1791_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
18
+
19
+ "low00748.png": './weights/flolv2_all_111439.pt',
20
+ "low00723.png": './weights/flolv2_all_111439.pt',
21
+ "low00772.png": './weights/flolv2_all_111439.pt'
22
+ }
23
 
24
  # Initial model setup (without weights)
25
  model = create_model()
 
29
  img_tensor = pil_to_tensor(img)
30
  return img_tensor
31
 
32
+ def process_img(image, filename):
33
+ # Select the correct weight file based on the image filename
34
+ if filename in image_to_weights:
35
+ model_path = image_to_weights[filename]
36
  checkpoints = torch.load(model_path, map_location=device)
37
  model.load_state_dict(checkpoints['params'])
38
  model.to(device)
 
63
  '''
64
 
65
  examples = [['images/425_UHD_LL.JPG'],
66
+ ['images/low00772.png'],
67
+ ['images/low00723.png'],
68
+ ['images/low00748.png'],
69
  ['images/1778_UHD_LL.JPG'],
70
  ['images/1791_UHD_LL.JPG']]
71
 
 
80
  demo = gr.Interface(
81
  fn=process_img,
82
  inputs=[
83
+ gr.Image(type='pil', label='input', tool='editor'),
84
+ gr.Textbox(label="Image Filename", interactive=False)
85
  ],
86
  outputs=[gr.Image(type='pil', label='output')],
87
  title=title,
 
90
  css=css
91
  )
92
 
93
+ # Updating the filename in the input after selection
94
+ def update_filename(image):
95
+ # Retrieve the filename from the input image
96
+ if image:
97
+ filename = image.filename # Gradio automatically gives the file name
98
+ return filename
99
+ return ""
100
+
101
+ # Define the update logic for filename
102
+ demo.input_components[0].change(update_filename, inputs=demo.input_components[0], outputs=demo.input_components[1])
103
+
104
  if __name__ == '__main__':
105
  demo.launch()