juaben commited on
Commit
2c5c388
·
verified ·
1 Parent(s): 578ffd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -37
app.py CHANGED
@@ -11,15 +11,7 @@ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cp
11
  pil_to_tensor = transforms.ToTensor()
12
 
13
  # Define a dictionary to map image filenames to weight files
14
- image_to_weights = {
15
- "425_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
16
- "1778_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
17
- "1791_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
18
-
19
- "low00748.png": './weights/flolv2_all_111439.pt',
20
- "low00723.png": './weights/flolv2_all_111439.pt',
21
- "low00772.png": './weights/flolv2_all_111439.pt'
22
- }
23
 
24
  # Initial model setup (without weights)
25
  model = create_model()
@@ -32,12 +24,12 @@ def load_img(filename):
32
  def process_img(image, filename):
33
  # Select the correct weight file based on the image filename
34
 
35
- filename = image.name.split("/")[-1]
36
- if filename in image_to_weights:
37
- model_path = image_to_weights[filename]
38
- checkpoints = torch.load(model_path, map_location=device)
39
- model.load_state_dict(checkpoints['params'])
40
- model.to(device)
41
 
42
  img = np.array(image)
43
  img = img / 255. # Normalize to [0, 1]
@@ -64,12 +56,11 @@ Due to the GPU memory limitations, the app might crash if you feed a high-resolu
64
  <br>
65
  '''
66
 
67
- examples = [['images/425_UHD_LL.JPG', '425_UHD_LL.JPG'],
68
- ['images/low00772.png', 'low00772.png'],
69
- ['images/low00723.png', 'low00723.png'],
70
- ['images/low00748.png', 'low00748.png'],
71
- ['images/1778_UHD_LL.JPG', '1778_UHD_LL.JPG'],
72
- ['images/1791_UHD_LL.JPG', '1791_UHD_LL.JPG']]
73
 
74
  css = """
75
  .image-frame img, .image-container img {
@@ -81,22 +72,17 @@ css = """
81
 
82
 
83
 
84
- with gr.Blocks() as demo:
85
- with gr.Row():
86
- input_image = gr.Image(type='pil', label='input', interactive = True)
87
- filename_output = gr.Textbox(label="Image Filename", interactive=False)
88
-
89
- # Define the output
90
- output_image = gr.Image(type='pil', label='output')
91
-
92
- # Define the interaction flow: when the image is uploaded, update the filename
93
- input_image.change(update_filename, inputs=input_image, outputs=filename_output)
94
-
95
- # Set the function to process the image with the filename
96
- input_image.change(process_img, inputs=[input_image, filename_output], outputs=output_image)
97
-
98
- # Provide examples
99
- demo.examples = examples
100
 
101
  if __name__ == '__main__':
102
  demo.launch()
 
11
  pil_to_tensor = transforms.ToTensor()
12
 
13
  # Define a dictionary to map image filenames to weight files
14
+ image_to_weights = {'./weights/flolv2_UHDLL.pt','./weights/flolv2_all_111439.pt'}
 
 
 
 
 
 
 
 
15
 
16
  # Initial model setup (without weights)
17
  model = create_model()
 
24
  def process_img(image, filename):
25
  # Select the correct weight file based on the image filename
26
 
27
+ # filename = image.name.split("/")[-1]
28
+ # if filename in image_to_weights:
29
+ model_path = image_to_weights[image_to_weights[1]]
30
+ checkpoints = torch.load(model_path, map_location=device)
31
+ model.load_state_dict(checkpoints['params'])
32
+ model.to(device)
33
 
34
  img = np.array(image)
35
  img = img / 255. # Normalize to [0, 1]
 
56
  <br>
57
  '''
58
 
59
+ examples = [
60
+ ['images/low00772.png'],
61
+ ['images/low00723.png'],
62
+ ['images/low00748.png'],
63
+ ]
 
64
 
65
  css = """
66
  .image-frame img, .image-container img {
 
72
 
73
 
74
 
75
+ demo = gr.Interface(
76
+ fn = process_img,
77
+ inputs = [
78
+ gr.Image(type = 'pil', label = 'input')
79
+ ],
80
+ outputs = [gr.Image(type='pil', label = 'output')],
81
+ title = title,
82
+ description = description,
83
+ examples = examples,
84
+ css = css
85
+ )
 
 
 
 
 
86
 
87
  if __name__ == '__main__':
88
  demo.launch()