VOIDER commited on
Commit
bdc8307
Β·
verified Β·
1 Parent(s): d995112

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -96
app.py CHANGED
@@ -19,91 +19,81 @@ except ImportError:
19
 
20
  # IQA-PyTorch imports
21
  try:
 
22
  from iqa_pytorch import IQA
23
- # Available models in IQA-PyTorch (examples for NR):
24
- # "MUSIQ-L2N-lessons", "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
25
- # "BRISQUE-PyTorch", "NIQE-PyTorch"
26
- # "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
27
- except ImportError:
28
- print("ERROR: IQA-PyTorch library import failed. Some metrics (NIQE, MUSIQ-NR) will be unavailable. Check installation.")
29
  IQA = None
30
 
 
31
  # --- Configuration ---
32
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  MAX_IMAGES_PER_BATCH = 100
34
  THUMBNAIL_SIZE = (64, 64) # (width, height) for preview
35
 
 
 
 
 
 
 
 
 
 
 
36
  # --- Metric Functions ---
37
 
38
  def get_brisque_score(img_tensor_chw_01):
39
- """Calculates BRISQUE score using PIQ. Expects a (C, H, W) tensor, range [0, 1]."""
40
  if piq is None: return "N/A (PIQ missing)"
41
  try:
42
  if img_tensor_chw_01.ndim == 3:
43
  img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
44
  else:
45
  img_tensor_bchw_01 = img_tensor_chw_01
46
-
47
  if img_tensor_bchw_01.shape[1] == 1:
48
  img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
49
-
50
  brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
51
  return round(brisque_loss.item(), 3)
52
- except Exception as e:
53
- return f"Error"
54
-
55
 
56
  def get_niqe_score(img_pil_rgb):
57
- """Calculates NIQE score using IQA-PyTorch. Expects a PIL RGB image."""
58
  if IQA is None: return "N/A (IQA missing)"
59
  try:
60
- niqe_metric = IQA(libs='NIQE-PyTorch', device=DEVICE) # NIQE is No-Reference
61
  score = niqe_metric(img_pil_rgb)
62
  return round(score.item(), 3)
63
- except Exception as e:
64
- return f"Error"
65
 
66
  def get_musiq_nr_score(img_pil_rgb):
67
- """Calculates No-Reference MUSIQ score using IQA-PyTorch. Expects a PIL RGB image."""
68
  if IQA is None: return "N/A (IQA missing)"
69
  try:
70
- musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
71
  score = musiq_metric(img_pil_rgb)
72
  return round(score.item(), 3)
73
- except Exception as e:
74
- return f"Error"
75
-
76
 
77
  def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
78
- """Calculates FID between two folders of images using PIQ."""
79
  if piq is None: return "N/A (PIQ missing)"
80
  try:
81
- set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder)
82
- if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
83
- set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder)
84
- if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
85
-
86
- if not set1_files or not set2_files:
87
- return "One or both sets have no valid image files."
88
- if len(set1_files) < 2 or len(set2_files) < 2:
89
- return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."
90
-
91
  fid_metric = piq.FID()
92
  set1_features = fid_metric.compute_feats(set1_files, device=DEVICE)
93
  set2_features = fid_metric.compute_feats(set2_files, device=DEVICE)
94
-
95
- if set1_features is None or set2_features is None:
96
- return "Could not extract features for one or both sets (check image validity and count)."
97
- if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0:
98
- return "Feature extraction resulted in empty tensors."
99
-
100
  fid_value = fid_metric(set1_features, set2_features)
101
  return round(fid_value.item(), 3)
102
  except Exception as e:
103
  print(f"FID calculation error: {e}")
104
  return f"FID Error: {str(e)[:100]}"
105
 
106
- # --- Helper Functions ---
107
  def pil_to_tensor_chw_01(img_pil_rgb):
108
  transform = T.Compose([T.ToTensor()])
109
  return transform(img_pil_rgb)
@@ -116,8 +106,39 @@ def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
116
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
117
  return f"data:image/png;base64,{img_str}"
118
 
119
- # --- Main Processing Functions for Gradio ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
121
  def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
122
  if not uploaded_file_list:
123
  return pd.DataFrame(), "Please upload images first."
@@ -132,6 +153,7 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
132
  results_data = []
133
 
134
  for i, file_obj in enumerate(uploaded_file_list):
 
135
  try:
136
  file_path = file_obj.name
137
  base_filename = os.path.basename(file_path)
@@ -141,6 +163,9 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
141
  brisque_val = get_brisque_score(img_tensor_chw_01)
142
  niqe_val = get_niqe_score(img_pil_rgb)
143
  musiq_nr_val = get_musiq_nr_score(img_pil_rgb)
 
 
 
144
  thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
145
  preview_html = f'<img src="{thumbnail_b64}" alt="{base_filename}">'
146
 
@@ -150,15 +175,15 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
150
  "BRISQUE (PIQ) (↓)": brisque_val,
151
  "NIQE (IQA-PyTorch) (↓)": niqe_val,
152
  "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
 
153
  })
154
  except Exception as e:
155
- try: base_filename = os.path.basename(file_obj.name if hasattr(file_obj, 'name') else str(file_obj))
156
- except: base_filename = "Unknown File"
157
  results_data.append({
158
  "Preview": "Error processing", "Filename": base_filename,
159
- "BRISQUE (PIQ) (↓)": f"Load Err",
160
  "NIQE (IQA-PyTorch) (↓)": "N/A",
161
  "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
 
162
  })
163
  progress((i + 1) / len(uploaded_file_list), desc=f"Processing {base_filename}")
164
 
@@ -166,38 +191,27 @@ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progres
166
  status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
167
  return df_results, status_message
168
 
169
-
170
  def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
171
  if not set1_file_list or not set2_file_list:
172
  return "Please upload files for both Set 1 and Set 2."
173
-
174
  set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
175
  set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
176
  fid_result_text = "Starting FID calculation..."
177
  progress(0.1, desc="Preparing image sets for FID...")
178
-
179
  try:
180
  for i, file_obj in enumerate(set1_file_list):
181
  shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name)))
182
  progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}")
183
-
184
  for i, file_obj in enumerate(set2_file_list):
185
  shutil.copy(file_obj.name, os.path.join(set2_dir, os.path.basename(file_obj.name)))
186
  progress(0.3 + 0.2 * (i / len(set2_file_list)), desc=f"Copying Set 2: {os.path.basename(file_obj.name)}")
187
-
188
- num_set1 = len(os.listdir(set1_dir))
189
- num_set2 = len(os.listdir(set2_dir))
190
-
191
- if num_set1 == 0 or num_set2 == 0:
192
- return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."
193
-
194
  progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
195
  fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
196
  progress(1, desc="FID calculation complete.")
197
  fid_result_text = f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
198
-
199
- except Exception as e:
200
- fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
201
  finally:
202
  if os.path.exists(set1_dir): shutil.rmtree(set1_dir)
203
  if os.path.exists(set2_dir): shutil.rmtree(set2_dir)
@@ -215,60 +229,42 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
215
  **Objective:** Automated evaluation and comparison of image quality from different model versions.
216
  Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
217
  (↓) means lower is better, (↑) means higher is better.
 
218
  """)
219
 
220
  with gr.Tabs():
221
  with gr.TabItem("Per-Image Quality Evaluation"):
222
- gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores. Images are processed in the browser's session.")
223
-
224
- image_upload_input = gr.Files(
225
- label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
226
- file_count="multiple",
227
- type="filepath"
228
- )
229
-
230
  evaluate_button_main = gr.Button("πŸ–ΌοΈ Evaluate Uploaded Images", variant="primary")
231
-
232
  gr.Markdown("---")
233
  status_output_main = gr.Textbox(label="πŸ“Š Evaluation Status", interactive=False, lines=2)
234
-
235
  gr.Markdown("### πŸ–ΌοΈ Per-Image Evaluation Results")
236
  gr.Markdown("Click column headers to sort. Previews are thumbnails.")
237
  results_table_output = gr.DataFrame(
238
- headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
239
- datatype=["html", "str", "number", "number", "number"],
240
  interactive=False,
241
  wrap=True,
242
- row_count=(15, "paginate") # MODIFIED: Replaced max_rows and overflow_row_behaviour
243
  )
244
 
245
  with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
246
  gr.Markdown("""
247
  Calculate FrΓ©chet Inception Distance (FID) between two sets of images.
248
- FID measures the similarity of two image distributions (e.g., generated vs. real, or version A vs. version B).
249
- **Lower FID scores are better**, indicating more similarity.
250
  """)
251
  with gr.Row():
252
- fid_set1_upload = gr.Files(label="Upload Images for Set 1 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
253
- fid_set2_upload = gr.Files(label="Upload Images for Set 2 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
254
-
255
  fid_calculate_button = gr.Button("πŸ”— Calculate FID between Set 1 and Set 2", variant="primary")
256
  fid_result_output = gr.Textbox(label="πŸ“ˆ FID Result", interactive=False, lines=2)
257
 
258
- evaluate_button_main.click(
259
- fn=process_images_for_individual_scores,
260
- inputs=[image_upload_input],
261
- outputs=[results_table_output, status_output_main]
262
- )
263
 
264
- fid_calculate_button.click(
265
- fn=process_fid_for_two_sets,
266
- inputs=[fid_set1_upload, fid_set2_upload],
267
- outputs=[fid_result_output]
268
- )
269
-
270
- # --- For Hugging Face Spaces: requirements.txt ---
271
- # Ensure this content is in your 'requirements.txt' file in the HF Space:
272
  """
273
  gradio
274
  torch
@@ -276,18 +272,14 @@ torchvision
276
  Pillow
277
  numpy
278
  piq>=0.8.0
279
- iqa-pytorch==0.2.1 # PINNED VERSION
280
  timm
281
  scikit-image
282
  pandas
 
283
  """
284
 
285
  if __name__ == "__main__":
286
- if piq is None:
287
- print("\n\nWARNING: PIQ library is missing.")
288
- print("Please install it: pip install piq\n\n")
289
- if IQA is None:
290
- print("\n\nERROR: IQA-PyTorch library import failed or it's missing.")
291
- print("Please ensure it's installed correctly (e.g., pip install iqa-pytorch==0.2.1) and check for import errors during startup.\n\n")
292
-
293
  demo.launch(debug=True)
 
19
 
20
  # IQA-PyTorch imports
21
  try:
22
+ # This import needs to succeed for NIQE and MUSIQ
23
  from iqa_pytorch import IQA
24
+ except ImportError as e:
25
+ print(f"ERROR: IQA-PyTorch library import failed: {e}. Some metrics (NIQE, MUSIQ-NR) will be unavailable. Check installation and dependencies (like kornia).")
26
+ IQA = None
27
+ except Exception as e:
28
+ print(f"ERROR: An unexpected error occurred during IQA-PyTorch import: {e}")
 
29
  IQA = None
30
 
31
+
32
  # --- Configuration ---
33
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
  MAX_IMAGES_PER_BATCH = 100
35
  THUMBNAIL_SIZE = (64, 64) # (width, height) for preview
36
 
37
+ # --- Metric Normalization Parameters (Approximate typical ranges) ---
38
+ # For "lower is better" metrics, score is (max_val - current_val) / (max_val - min_val)
39
+ # For "higher is better" metrics, score is (current_val - min_val) / (max_val - min_val)
40
+ # These are heuristics and can be adjusted.
41
+ METRIC_RANGES = {
42
+ "brisque": {"min": 0, "max": 120, "lower_is_better": True}, # Typical BRISQUE range
43
+ "niqe": {"min": 0, "max": 12, "lower_is_better": True}, # Typical NIQE range
44
+ "musiq_nr": {"min": 10, "max": 90, "lower_is_better": False} # Example MUSIQ range
45
+ }
46
+
47
  # --- Metric Functions ---
48
 
49
  def get_brisque_score(img_tensor_chw_01):
 
50
  if piq is None: return "N/A (PIQ missing)"
51
  try:
52
  if img_tensor_chw_01.ndim == 3:
53
  img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
54
  else:
55
  img_tensor_bchw_01 = img_tensor_chw_01
 
56
  if img_tensor_bchw_01.shape[1] == 1:
57
  img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
 
58
  brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
59
  return round(brisque_loss.item(), 3)
60
+ except Exception: return "Error"
 
 
61
 
62
  def get_niqe_score(img_pil_rgb):
 
63
  if IQA is None: return "N/A (IQA missing)"
64
  try:
65
+ niqe_metric = IQA(libs='NIQE-PyTorch', device=DEVICE)
66
  score = niqe_metric(img_pil_rgb)
67
  return round(score.item(), 3)
68
+ except Exception: return "Error"
 
69
 
70
  def get_musiq_nr_score(img_pil_rgb):
 
71
  if IQA is None: return "N/A (IQA missing)"
72
  try:
73
+ musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE) # Example, could be other MUSIQ variants
74
  score = musiq_metric(img_pil_rgb)
75
  return round(score.item(), 3)
76
+ except Exception: return "Error"
 
 
77
 
78
  def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
 
79
  if piq is None: return "N/A (PIQ missing)"
80
  try:
81
+ set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
82
+ set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
83
+ if not set1_files or not set2_files: return "One or both sets have no valid image files."
84
+ if len(set1_files) < 2 or len(set2_files) < 2: return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."
 
 
 
 
 
 
85
  fid_metric = piq.FID()
86
  set1_features = fid_metric.compute_feats(set1_files, device=DEVICE)
87
  set2_features = fid_metric.compute_feats(set2_files, device=DEVICE)
88
+ if set1_features is None or set2_features is None: return "Could not extract features for one or both sets."
89
+ if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0: return "Feature extraction resulted in empty tensors."
 
 
 
 
90
  fid_value = fid_metric(set1_features, set2_features)
91
  return round(fid_value.item(), 3)
92
  except Exception as e:
93
  print(f"FID calculation error: {e}")
94
  return f"FID Error: {str(e)[:100]}"
95
 
96
+ # --- Helper & Final Score Calculation ---
97
  def pil_to_tensor_chw_01(img_pil_rgb):
98
  transform = T.Compose([T.ToTensor()])
99
  return transform(img_pil_rgb)
 
106
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
107
  return f"data:image/png;base64,{img_str}"
108
 
109
+ def calculate_final_score(brisque_val, niqe_val, musiq_nr_val):
110
+ normalized_scores = []
111
+
112
+ # BRISQUE
113
+ if isinstance(brisque_val, (float, int)):
114
+ cfg = METRIC_RANGES["brisque"]
115
+ val = max(cfg["min"], min(cfg["max"], brisque_val)) # Clip
116
+ norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"])
117
+ normalized_scores.append(norm_score)
118
+
119
+ # NIQE
120
+ if isinstance(niqe_val, (float, int)):
121
+ cfg = METRIC_RANGES["niqe"]
122
+ val = max(cfg["min"], min(cfg["max"], niqe_val)) # Clip
123
+ norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"])
124
+ normalized_scores.append(norm_score)
125
+
126
+ # MUSIQ-NR
127
+ if isinstance(musiq_nr_val, (float, int)):
128
+ cfg = METRIC_RANGES["musiq_nr"]
129
+ val = max(cfg["min"], min(cfg["max"], musiq_nr_val)) # Clip
130
+ norm_score = (cfg["max"] - val) / (cfg["max"] - cfg["min"]) if cfg["lower_is_better"] else (val - cfg["min"]) / (cfg["max"] - cfg["min"])
131
+ normalized_scores.append(norm_score)
132
+
133
+ if not normalized_scores:
134
+ return "N/A"
135
+
136
+ # Average of normalized scores, then scale to 0-10
137
+ final_score_0_10 = (sum(normalized_scores) / len(normalized_scores)) * 10.0
138
+ return round(final_score_0_10, 4)
139
+
140
 
141
+ # --- Main Processing Functions for Gradio ---
142
  def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
143
  if not uploaded_file_list:
144
  return pd.DataFrame(), "Please upload images first."
 
153
  results_data = []
154
 
155
  for i, file_obj in enumerate(uploaded_file_list):
156
+ base_filename = "Unknown File"
157
  try:
158
  file_path = file_obj.name
159
  base_filename = os.path.basename(file_path)
 
163
  brisque_val = get_brisque_score(img_tensor_chw_01)
164
  niqe_val = get_niqe_score(img_pil_rgb)
165
  musiq_nr_val = get_musiq_nr_score(img_pil_rgb)
166
+
167
+ final_score = calculate_final_score(brisque_val, niqe_val, musiq_nr_val)
168
+
169
  thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
170
  preview_html = f'<img src="{thumbnail_b64}" alt="{base_filename}">'
171
 
 
175
  "BRISQUE (PIQ) (↓)": brisque_val,
176
  "NIQE (IQA-PyTorch) (↓)": niqe_val,
177
  "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
178
+ "Final Score (0-10) (↑)": final_score,
179
  })
180
  except Exception as e:
 
 
181
  results_data.append({
182
  "Preview": "Error processing", "Filename": base_filename,
183
+ "BRISQUE (PIQ) (↓)": f"Load Err: {str(e)[:30]}",
184
  "NIQE (IQA-PyTorch) (↓)": "N/A",
185
  "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
186
+ "Final Score (0-10) (↑)": "N/A",
187
  })
188
  progress((i + 1) / len(uploaded_file_list), desc=f"Processing {base_filename}")
189
 
 
191
  status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
192
  return df_results, status_message
193
 
 
194
  def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
195
  if not set1_file_list or not set2_file_list:
196
  return "Please upload files for both Set 1 and Set 2."
 
197
  set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
198
  set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
199
  fid_result_text = "Starting FID calculation..."
200
  progress(0.1, desc="Preparing image sets for FID...")
 
201
  try:
202
  for i, file_obj in enumerate(set1_file_list):
203
  shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name)))
204
  progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}")
 
205
  for i, file_obj in enumerate(set2_file_list):
206
  shutil.copy(file_obj.name, os.path.join(set2_dir, os.path.basename(file_obj.name)))
207
  progress(0.3 + 0.2 * (i / len(set2_file_list)), desc=f"Copying Set 2: {os.path.basename(file_obj.name)}")
208
+ num_set1 = len(os.listdir(set1_dir)); num_set2 = len(os.listdir(set2_dir))
209
+ if num_set1 == 0 or num_set2 == 0: return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."
 
 
 
 
 
210
  progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
211
  fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
212
  progress(1, desc="FID calculation complete.")
213
  fid_result_text = f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
214
+ except Exception as e: fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
 
 
215
  finally:
216
  if os.path.exists(set1_dir): shutil.rmtree(set1_dir)
217
  if os.path.exists(set2_dir): shutil.rmtree(set2_dir)
 
229
  **Objective:** Automated evaluation and comparison of image quality from different model versions.
230
  Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
231
  (↓) means lower is better, (↑) means higher is better.
232
+ Final Score is a heuristic combination of available metrics (0-10, higher is better).
233
  """)
234
 
235
  with gr.Tabs():
236
  with gr.TabItem("Per-Image Quality Evaluation"):
237
+ gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores.")
238
+ image_upload_input = gr.Files(label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
 
 
 
 
 
 
239
  evaluate_button_main = gr.Button("πŸ–ΌοΈ Evaluate Uploaded Images", variant="primary")
 
240
  gr.Markdown("---")
241
  status_output_main = gr.Textbox(label="πŸ“Š Evaluation Status", interactive=False, lines=2)
 
242
  gr.Markdown("### πŸ–ΌοΈ Per-Image Evaluation Results")
243
  gr.Markdown("Click column headers to sort. Previews are thumbnails.")
244
  results_table_output = gr.DataFrame(
245
+ headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)", "Final Score (0-10) (↑)"],
246
+ datatype=["html", "str", "number", "number", "number", "number"], # Added "number" for Final Score
247
  interactive=False,
248
  wrap=True,
249
+ row_count=(15, "paginate")
250
  )
251
 
252
  with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
253
  gr.Markdown("""
254
  Calculate FrΓ©chet Inception Distance (FID) between two sets of images.
255
+ FID measures the similarity of two image distributions. **Lower FID scores are better**.
 
256
  """)
257
  with gr.Row():
258
+ fid_set1_upload = gr.Files(label="Upload Images for Set 1", file_count="multiple", type="filepath")
259
+ fid_set2_upload = gr.Files(label="Upload Images for Set 2", file_count="multiple", type="filepath")
 
260
  fid_calculate_button = gr.Button("πŸ”— Calculate FID between Set 1 and Set 2", variant="primary")
261
  fid_result_output = gr.Textbox(label="πŸ“ˆ FID Result", interactive=False, lines=2)
262
 
263
+ evaluate_button_main.click(fn=process_images_for_individual_scores, inputs=[image_upload_input], outputs=[results_table_output, status_output_main])
264
+ fid_calculate_button.click(fn=process_fid_for_two_sets, inputs=[fid_set1_upload, fid_set2_upload], outputs=[fid_result_output])
 
 
 
265
 
266
+ # --- For Hugging Face Spaces ---
267
+ # Ensure 'requirements.txt' includes:
 
 
 
 
 
 
268
  """
269
  gradio
270
  torch
 
272
  Pillow
273
  numpy
274
  piq>=0.8.0
275
+ iqa-pytorch==0.1
276
  timm
277
  scikit-image
278
  pandas
279
+ kornia
280
  """
281
 
282
  if __name__ == "__main__":
283
+ if piq is None: print("\nWARNING: PIQ library is missing. pip install piq\n")
284
+ if IQA is None: print("\nERROR: IQA-PyTorch library import failed. pip install iqa-pytorch==0.1 kornia\n")
 
 
 
 
 
285
  demo.launch(debug=True)