VOIDER committed on
Commit
5157b8f
Β·
verified Β·
1 Parent(s): dbf8e93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -64
app.py CHANGED
@@ -25,7 +25,7 @@ try:
25
  # "BRISQUE-PyTorch", "NIQE-PyTorch"
26
  # "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
27
  except ImportError:
28
- print("Warning: IQA-PyTorch library not found. Some metrics (NIQE, MUSIQ-NR) will be unavailable.")
29
  IQA = None
30
 
31
  # --- Configuration ---
@@ -39,20 +39,17 @@ def get_brisque_score(img_tensor_chw_01):
39
  """Calculates BRISQUE score using PIQ. Expects a (C, H, W) tensor, range [0, 1]."""
40
  if piq is None: return "N/A (PIQ missing)"
41
  try:
42
- # Ensure tensor is (B, C, H, W) for piq.brisque
43
  if img_tensor_chw_01.ndim == 3:
44
  img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
45
- else: # Already has batch dim or incorrect dims
46
  img_tensor_bchw_01 = img_tensor_chw_01
47
 
48
- # Ensure 3 channels if it's grayscale by repeating
49
  if img_tensor_bchw_01.shape[1] == 1:
50
  img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
51
 
52
  brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
53
  return round(brisque_loss.item(), 3)
54
  except Exception as e:
55
- # print(f"BRISQUE Error: {e} for tensor shape {img_tensor_chw_01.shape}")
56
  return f"Error"
57
 
58
 
@@ -64,20 +61,16 @@ def get_niqe_score(img_pil_rgb):
64
  score = niqe_metric(img_pil_rgb)
65
  return round(score.item(), 3)
66
  except Exception as e:
67
- # print(f"NIQE Error: {e}")
68
  return f"Error"
69
 
70
  def get_musiq_nr_score(img_pil_rgb):
71
  """Calculates No-Reference MUSIQ score using IQA-PyTorch. Expects a PIL RGB image."""
72
  if IQA is None: return "N/A (IQA missing)"
73
  try:
74
- # Using MUSIQ-L2N-lessons as an example NR model from IQA-PyTorch
75
- # Other options: "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
76
  musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
77
  score = musiq_metric(img_pil_rgb)
78
  return round(score.item(), 3)
79
  except Exception as e:
80
- # print(f"MUSIQ-NR Error: {e}")
81
  return f"Error"
82
 
83
 
@@ -85,7 +78,6 @@ def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
85
  """Calculates FID between two folders of images using PIQ."""
86
  if piq is None: return "N/A (PIQ missing)"
87
  try:
88
- # List image files in folders
89
  set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder)
90
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
91
  set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder)
@@ -93,21 +85,19 @@ def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
93
 
94
  if not set1_files or not set2_files:
95
  return "One or both sets have no valid image files."
96
- if len(set1_files) < 2 or len(set2_files) < 2: # FID usually needs more, but PIQ might handle small N. Min 2 to compute stats.
97
  return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."
98
 
99
  fid_metric = piq.FID()
100
- # compute_feats expects a list of image paths
101
  set1_features = fid_metric.compute_feats(set1_files, device=DEVICE)
102
  set2_features = fid_metric.compute_feats(set2_files, device=DEVICE)
103
 
104
  if set1_features is None or set2_features is None:
105
  return "Could not extract features for one or both sets (check image validity and count)."
106
- if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0: # Handle empty tensors
107
  return "Feature extraction resulted in empty tensors."
108
 
109
-
110
- fid_value = fid_metric(set1_features, set2_features) # Pass tensors directly
111
  return round(fid_value.item(), 3)
112
  except Exception as e:
113
  print(f"FID calculation error: {e}")
@@ -115,12 +105,10 @@ def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
115
 
116
  # --- Helper Functions ---
117
  def pil_to_tensor_chw_01(img_pil_rgb):
118
- """Converts PIL RGB image to PyTorch CHW tensor [0,1]."""
119
- transform = T.Compose([T.ToTensor()]) # Converts PIL [0,255] to Tensor [0,1] C,H,W
120
  return transform(img_pil_rgb)
121
 
122
  def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
123
- """Creates a base64 encoded PNG thumbnail string from a PIL image."""
124
  img_copy = img_pil_rgb.copy()
125
  img_copy.thumbnail(size)
126
  buffered = io.BytesIO()
@@ -131,11 +119,9 @@ def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
131
  # --- Main Processing Functions for Gradio ---
132
 
133
  def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
134
- """Processes uploaded images for individual quality scores and displays them."""
135
  if not uploaded_file_list:
136
- return pd.DataFrame(), "Please upload images first.", "IS: N/A (Not Implemented)", "FID: N/A (Use FID Tab)"
137
 
138
- # Limit number of images
139
  if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
140
  status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
141
  uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
@@ -143,28 +129,18 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
143
  status_message = f"Processing {len(uploaded_file_list)} images..."
144
 
145
  progress(0, desc=status_message)
146
-
147
  results_data = []
148
- # Temporary directory for this batch if needed by some metric that takes a folder path
149
- # batch_temp_dir = tempfile.mkdtemp(prefix="eval_batch_")
150
 
151
  for i, file_obj in enumerate(uploaded_file_list):
152
  try:
153
- # file_obj for gr.Files is a tempfile._TemporaryFileWrapper object
154
  file_path = file_obj.name
155
  base_filename = os.path.basename(file_path)
156
-
157
  img_pil_rgb = Image.open(file_path).convert("RGB")
158
 
159
- # 1. For PIQ BRISQUE (needs tensor)
160
  img_tensor_chw_01 = pil_to_tensor_chw_01(img_pil_rgb)
161
  brisque_val = get_brisque_score(img_tensor_chw_01)
162
-
163
- # 2. For IQA-PyTorch NIQE & MUSIQ (needs PIL image)
164
  niqe_val = get_niqe_score(img_pil_rgb)
165
  musiq_nr_val = get_musiq_nr_score(img_pil_rgb)
166
-
167
- # 3. Thumbnail for display
168
  thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
169
  preview_html = f'<img src="{thumbnail_b64}" alt="{base_filename}">'
170
 
@@ -188,33 +164,19 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
188
 
189
  df_results = pd.DataFrame(results_data)
190
  status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
191
-
192
- # Batch metrics info (IS not implemented, FID separate)
193
- is_text = "IS (PIQ): Not implemented in this version."
194
- fid_text_batch_info = "FID (PIQ): Use the 'Calculate FID (Set vs Set)' tab for FID scores."
195
-
196
- # Cleanup temp dir if created
197
- # if os.path.exists(batch_temp_dir): shutil.rmtree(batch_temp_dir)
198
-
199
- return df_results, status_message, is_text, fid_text_batch_info
200
 
201
 
202
  def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
203
- """Handles FID calculation between two sets of uploaded images."""
204
  if not set1_file_list or not set2_file_list:
205
  return "Please upload files for both Set 1 and Set 2."
206
 
207
- # Create temporary directories for Set 1 and Set 2
208
- # Suffix helps identify user folders if many users hit it, though Gradio handles sessions.
209
- # Prefix is better for mkdtemp.
210
  set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
211
  set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
212
-
213
  fid_result_text = "Starting FID calculation..."
214
  progress(0.1, desc="Preparing image sets for FID...")
215
 
216
  try:
217
- # Copy uploaded files to these temporary directories
218
  for i, file_obj in enumerate(set1_file_list):
219
  shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name)))
220
  progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}")
@@ -237,10 +199,8 @@ def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progres
237
  except Exception as e:
238
  fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
239
  finally:
240
- # Cleanup temporary directories
241
  if os.path.exists(set1_dir): shutil.rmtree(set1_dir)
242
  if os.path.exists(set2_dir): shutil.rmtree(set2_dir)
243
-
244
  return fid_result_text
245
 
246
  # --- Gradio UI Definition ---
@@ -264,7 +224,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
264
  image_upload_input = gr.Files(
265
  label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
266
  file_count="multiple",
267
- type="filepath"
268
  )
269
 
270
  evaluate_button_main = gr.Button("πŸ–ΌοΈ Evaluate Uploaded Images", variant="primary")
@@ -274,14 +234,12 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
274
 
275
  gr.Markdown("### πŸ–ΌοΈ Per-Image Evaluation Results")
276
  gr.Markdown("Click column headers to sort. Previews are thumbnails.")
277
- # MODIFIED LINE BELOW:
278
  results_table_output = gr.DataFrame(
279
  headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
280
  datatype=["html", "str", "number", "number", "number"],
281
- interactive=False,
282
- wrap=True,
283
- overflow_row_behaviour="paginate", # max_rows removed
284
- # height=450 # Optional: Set a fixed height in pixels if you want ~15 rows visible before scrolling within the component
285
  )
286
 
287
  with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
@@ -297,11 +255,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
297
  fid_calculate_button = gr.Button("πŸ”— Calculate FID between Set 1 and Set 2", variant="primary")
298
  fid_result_output = gr.Textbox(label="πŸ“ˆ FID Result", interactive=False, lines=2)
299
 
300
- # Wire components
301
  evaluate_button_main.click(
302
  fn=process_images_for_individual_scores,
303
  inputs=[image_upload_input],
304
- outputs=[results_table_output, status_output_main] #, batch_is_output, batch_fid_output_info]
305
  )
306
 
307
  fid_calculate_button.click(
@@ -318,16 +275,19 @@ torch
318
  torchvision
319
  Pillow
320
  numpy
321
- piq>=0.8.0 # Specify version if known good, or just piq
322
- iqa-pytorch>=0.2.1 # Specify version if known good
323
- timm # A dependency for some iqa-pytorch models like MUSIQ
324
- scikit-image # Often a transitive dependency, good to include
325
  pandas
326
  """
327
 
328
  if __name__ == "__main__":
329
- if piq is None or IQA is None:
330
- print("\n\nWARNING: One or more core libraries (PIQ, IQA-PyTorch) are missing.")
331
- print("Please install them by creating a 'requirements.txt' file with the content above and running: pip install -r requirements.txt\n\n")
 
 
 
332
 
333
- demo.launch(debug=True) # Set debug=False for production
 
25
  # "BRISQUE-PyTorch", "NIQE-PyTorch"
26
  # "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
27
  except ImportError:
28
+ print("ERROR: IQA-PyTorch library import failed. Some metrics (NIQE, MUSIQ-NR) will be unavailable. Check installation.")
29
  IQA = None
30
 
31
  # --- Configuration ---
 
39
  """Calculates BRISQUE score using PIQ. Expects a (C, H, W) tensor, range [0, 1]."""
40
  if piq is None: return "N/A (PIQ missing)"
41
  try:
 
42
  if img_tensor_chw_01.ndim == 3:
43
  img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
44
+ else:
45
  img_tensor_bchw_01 = img_tensor_chw_01
46
 
 
47
  if img_tensor_bchw_01.shape[1] == 1:
48
  img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
49
 
50
  brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
51
  return round(brisque_loss.item(), 3)
52
  except Exception as e:
 
53
  return f"Error"
54
 
55
 
 
61
  score = niqe_metric(img_pil_rgb)
62
  return round(score.item(), 3)
63
  except Exception as e:
 
64
  return f"Error"
65
 
66
  def get_musiq_nr_score(img_pil_rgb):
67
  """Calculates No-Reference MUSIQ score using IQA-PyTorch. Expects a PIL RGB image."""
68
  if IQA is None: return "N/A (IQA missing)"
69
  try:
 
 
70
  musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
71
  score = musiq_metric(img_pil_rgb)
72
  return round(score.item(), 3)
73
  except Exception as e:
 
74
  return f"Error"
75
 
76
 
 
78
  """Calculates FID between two folders of images using PIQ."""
79
  if piq is None: return "N/A (PIQ missing)"
80
  try:
 
81
  set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder)
82
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
83
  set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder)
 
85
 
86
  if not set1_files or not set2_files:
87
  return "One or both sets have no valid image files."
88
+ if len(set1_files) < 2 or len(set2_files) < 2:
89
  return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."
90
 
91
  fid_metric = piq.FID()
 
92
  set1_features = fid_metric.compute_feats(set1_files, device=DEVICE)
93
  set2_features = fid_metric.compute_feats(set2_files, device=DEVICE)
94
 
95
  if set1_features is None or set2_features is None:
96
  return "Could not extract features for one or both sets (check image validity and count)."
97
+ if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0:
98
  return "Feature extraction resulted in empty tensors."
99
 
100
+ fid_value = fid_metric(set1_features, set2_features)
 
101
  return round(fid_value.item(), 3)
102
  except Exception as e:
103
  print(f"FID calculation error: {e}")
 
105
 
106
  # --- Helper Functions ---
107
  def pil_to_tensor_chw_01(img_pil_rgb):
108
+ transform = T.Compose([T.ToTensor()])
 
109
  return transform(img_pil_rgb)
110
 
111
  def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
 
112
  img_copy = img_pil_rgb.copy()
113
  img_copy.thumbnail(size)
114
  buffered = io.BytesIO()
 
119
  # --- Main Processing Functions for Gradio ---
120
 
121
  def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
 
122
  if not uploaded_file_list:
123
+ return pd.DataFrame(), "Please upload images first."
124
 
 
125
  if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
126
  status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
127
  uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
 
129
  status_message = f"Processing {len(uploaded_file_list)} images..."
130
 
131
  progress(0, desc=status_message)
 
132
  results_data = []
 
 
133
 
134
  for i, file_obj in enumerate(uploaded_file_list):
135
  try:
 
136
  file_path = file_obj.name
137
  base_filename = os.path.basename(file_path)
 
138
  img_pil_rgb = Image.open(file_path).convert("RGB")
139
 
 
140
  img_tensor_chw_01 = pil_to_tensor_chw_01(img_pil_rgb)
141
  brisque_val = get_brisque_score(img_tensor_chw_01)
 
 
142
  niqe_val = get_niqe_score(img_pil_rgb)
143
  musiq_nr_val = get_musiq_nr_score(img_pil_rgb)
 
 
144
  thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
145
  preview_html = f'<img src="{thumbnail_b64}" alt="{base_filename}">'
146
 
 
164
 
165
  df_results = pd.DataFrame(results_data)
166
  status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
167
+ return df_results, status_message
 
 
 
 
 
 
 
 
168
 
169
 
170
  def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
 
171
  if not set1_file_list or not set2_file_list:
172
  return "Please upload files for both Set 1 and Set 2."
173
 
 
 
 
174
  set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
175
  set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
 
176
  fid_result_text = "Starting FID calculation..."
177
  progress(0.1, desc="Preparing image sets for FID...")
178
 
179
  try:
 
180
  for i, file_obj in enumerate(set1_file_list):
181
  shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name)))
182
  progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}")
 
199
  except Exception as e:
200
  fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
201
  finally:
 
202
  if os.path.exists(set1_dir): shutil.rmtree(set1_dir)
203
  if os.path.exists(set2_dir): shutil.rmtree(set2_dir)
 
204
  return fid_result_text
205
 
206
  # --- Gradio UI Definition ---
 
224
  image_upload_input = gr.Files(
225
  label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
226
  file_count="multiple",
227
+ type="filepath"
228
  )
229
 
230
  evaluate_button_main = gr.Button("πŸ–ΌοΈ Evaluate Uploaded Images", variant="primary")
 
234
 
235
  gr.Markdown("### πŸ–ΌοΈ Per-Image Evaluation Results")
236
  gr.Markdown("Click column headers to sort. Previews are thumbnails.")
 
237
  results_table_output = gr.DataFrame(
238
  headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
239
  datatype=["html", "str", "number", "number", "number"],
240
+ interactive=False,
241
+ wrap=True,
242
+ row_count=(15, "paginate") # MODIFIED: Replaced max_rows and overflow_row_behaviour
 
243
  )
244
 
245
  with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
 
255
  fid_calculate_button = gr.Button("πŸ”— Calculate FID between Set 1 and Set 2", variant="primary")
256
  fid_result_output = gr.Textbox(label="πŸ“ˆ FID Result", interactive=False, lines=2)
257
 
 
258
  evaluate_button_main.click(
259
  fn=process_images_for_individual_scores,
260
  inputs=[image_upload_input],
261
+ outputs=[results_table_output, status_output_main]
262
  )
263
 
264
  fid_calculate_button.click(
 
275
  torchvision
276
  Pillow
277
  numpy
278
+ piq>=0.8.0
279
+ iqa-pytorch==0.2.1 # PINNED VERSION
280
+ timm
281
+ scikit-image
282
  pandas
283
  """
284
 
285
  if __name__ == "__main__":
286
+ if piq is None:
287
+ print("\n\nWARNING: PIQ library is missing.")
288
+ print("Please install it: pip install piq\n\n")
289
+ if IQA is None:
290
+ print("\n\nERROR: IQA-PyTorch library import failed or it's missing.")
291
+ print("Please ensure it's installed correctly (e.g., pip install iqa-pytorch==0.2.1) and check for import errors during startup.\n\n")
292
 
293
+ demo.launch(debug=True)