VOIDER commited on
Commit
8c2a1e0
·
verified ·
1 Parent(s): 67b6aee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ import numpy as np
7
+ import io
8
+ import base64
9
+ import os
10
+ import shutil
11
+ import tempfile
12
+
13
+ # PIQ imports
14
+ try:
15
+ import piq
16
+ except ImportError:
17
+ print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.")
18
+ piq = None
19
+
20
+ # IQA-PyTorch imports
21
+ try:
22
+ from iqa_pytorch import IQA
23
+ # Available models in IQA-PyTorch (examples for NR):
24
+ # "MUSIQ-L2N-lessons", "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
25
+ # "BRISQUE-PyTorch", "NIQE-PyTorch"
26
+ # "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
27
+ except ImportError:
28
+ print("Warning: IQA-PyTorch library not found. Some metrics (NIQE, MUSIQ-NR) will be unavailable.")
29
+ IQA = None
30
+
31
+ # --- Configuration ---
32
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ MAX_IMAGES_PER_BATCH = 100
34
+ THUMBNAIL_SIZE = (64, 64) # (width, height) for preview
35
+
36
+ # --- Metric Functions ---
37
+
38
+ def get_brisque_score(img_tensor_chw_01):
39
+ """Calculates BRISQUE score using PIQ. Expects a (C, H, W) tensor, range [0, 1]."""
40
+ if piq is None: return "N/A (PIQ missing)"
41
+ try:
42
+ # Ensure tensor is (B, C, H, W) for piq.brisque
43
+ if img_tensor_chw_01.ndim == 3:
44
+ img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
45
+ else: # Already has batch dim or incorrect dims
46
+ img_tensor_bchw_01 = img_tensor_chw_01
47
+
48
+ # Ensure 3 channels if it's grayscale by repeating
49
+ if img_tensor_bchw_01.shape[1] == 1:
50
+ img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
51
+
52
+ brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
53
+ return round(brisque_loss.item(), 3)
54
+ except Exception as e:
55
+ # print(f"BRISQUE Error: {e} for tensor shape {img_tensor_chw_01.shape}")
56
+ return f"Error"
57
+
58
+
59
+ def get_niqe_score(img_pil_rgb):
60
+ """Calculates NIQE score using IQA-PyTorch. Expects a PIL RGB image."""
61
+ if IQA is None: return "N/A (IQA missing)"
62
+ try:
63
+ niqe_metric = IQA(libs='NIQE-PyTorch', device=DEVICE) # NIQE is No-Reference
64
+ score = niqe_metric(img_pil_rgb)
65
+ return round(score.item(), 3)
66
+ except Exception as e:
67
+ # print(f"NIQE Error: {e}")
68
+ return f"Error"
69
+
70
+ def get_musiq_nr_score(img_pil_rgb):
71
+ """Calculates No-Reference MUSIQ score using IQA-PyTorch. Expects a PIL RGB image."""
72
+ if IQA is None: return "N/A (IQA missing)"
73
+ try:
74
+ # Using MUSIQ-L2N-lessons as an example NR model from IQA-PyTorch
75
+ # Other options: "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
76
+ musiq_metric = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
77
+ score = musiq_metric(img_pil_rgb)
78
+ return round(score.item(), 3)
79
+ except Exception as e:
80
+ # print(f"MUSIQ-NR Error: {e}")
81
+ return f"Error"
82
+
83
+
84
+ def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
85
+ """Calculates FID between two folders of images using PIQ."""
86
+ if piq is None: return "N/A (PIQ missing)"
87
+ try:
88
+ # List image files in folders
89
+ set1_files = [os.path.join(path_to_set1_folder, f) for f in os.listdir(path_to_set1_folder)
90
+ if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
91
+ set2_files = [os.path.join(path_to_set2_folder, f) for f in os.listdir(path_to_set2_folder)
92
+ if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
93
+
94
+ if not set1_files or not set2_files:
95
+ return "One or both sets have no valid image files."
96
+ if len(set1_files) < 2 or len(set2_files) < 2: # FID usually needs more, but PIQ might handle small N. Min 2 to compute stats.
97
+ return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."
98
+
99
+ fid_metric = piq.FID()
100
+ # compute_feats expects a list of image paths
101
+ set1_features = fid_metric.compute_feats(set1_files, device=DEVICE)
102
+ set2_features = fid_metric.compute_feats(set2_files, device=DEVICE)
103
+
104
+ if set1_features is None or set2_features is None:
105
+ return "Could not extract features for one or both sets (check image validity and count)."
106
+ if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0: # Handle empty tensors
107
+ return "Feature extraction resulted in empty tensors."
108
+
109
+
110
+ fid_value = fid_metric(set1_features, set2_features) # Pass tensors directly
111
+ return round(fid_value.item(), 3)
112
+ except Exception as e:
113
+ print(f"FID calculation error: {e}")
114
+ return f"FID Error: {str(e)[:100]}"
115
+
116
+ # --- Helper Functions ---
117
+ def pil_to_tensor_chw_01(img_pil_rgb):
118
+ """Converts PIL RGB image to PyTorch CHW tensor [0,1]."""
119
+ transform = T.Compose([T.ToTensor()]) # Converts PIL [0,255] to Tensor [0,1] C,H,W
120
+ return transform(img_pil_rgb)
121
+
122
+ def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
123
+ """Creates a base64 encoded PNG thumbnail string from a PIL image."""
124
+ img_copy = img_pil_rgb.copy()
125
+ img_copy.thumbnail(size)
126
+ buffered = io.BytesIO()
127
+ img_copy.save(buffered, format="PNG")
128
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
129
+ return f"data:image/png;base64,{img_str}"
130
+
131
+ # --- Main Processing Functions for Gradio ---
132
+
133
+ def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
134
+ """Processes uploaded images for individual quality scores and displays them."""
135
+ if not uploaded_file_list:
136
+ return pd.DataFrame(), "Please upload images first.", "IS: N/A (Not Implemented)", "FID: N/A (Use FID Tab)"
137
+
138
+ # Limit number of images
139
+ if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
140
+ status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
141
+ uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
142
+ else:
143
+ status_message = f"Processing {len(uploaded_file_list)} images..."
144
+
145
+ progress(0, desc=status_message)
146
+
147
+ results_data = []
148
+ # Temporary directory for this batch if needed by some metric that takes a folder path
149
+ # batch_temp_dir = tempfile.mkdtemp(prefix="eval_batch_")
150
+
151
+ for i, file_obj in enumerate(uploaded_file_list):
152
+ try:
153
+ # file_obj for gr.Files is a tempfile._TemporaryFileWrapper object
154
+ file_path = file_obj.name
155
+ base_filename = os.path.basename(file_path)
156
+
157
+ img_pil_rgb = Image.open(file_path).convert("RGB")
158
+
159
+ # 1. For PIQ BRISQUE (needs tensor)
160
+ img_tensor_chw_01 = pil_to_tensor_chw_01(img_pil_rgb)
161
+ brisque_val = get_brisque_score(img_tensor_chw_01)
162
+
163
+ # 2. For IQA-PyTorch NIQE & MUSIQ (needs PIL image)
164
+ niqe_val = get_niqe_score(img_pil_rgb)
165
+ musiq_nr_val = get_musiq_nr_score(img_pil_rgb)
166
+
167
+ # 3. Thumbnail for display
168
+ thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
169
+ preview_html = f'<img src="{thumbnail_b64}" alt="{base_filename}">'
170
+
171
+ results_data.append({
172
+ "Preview": preview_html,
173
+ "Filename": base_filename,
174
+ "BRISQUE (PIQ) (↓)": brisque_val,
175
+ "NIQE (IQA-PyTorch) (↓)": niqe_val,
176
+ "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
177
+ })
178
+ except Exception as e:
179
+ try: base_filename = os.path.basename(file_obj.name if hasattr(file_obj, 'name') else str(file_obj))
180
+ except: base_filename = "Unknown File"
181
+ results_data.append({
182
+ "Preview": "Error processing", "Filename": base_filename,
183
+ "BRISQUE (PIQ) (↓)": f"Load Err",
184
+ "NIQE (IQA-PyTorch) (↓)": "N/A",
185
+ "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
186
+ })
187
+ progress((i + 1) / len(uploaded_file_list), desc=f"Processing {base_filename}")
188
+
189
+ df_results = pd.DataFrame(results_data)
190
+ status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
191
+
192
+ # Batch metrics info (IS not implemented, FID separate)
193
+ is_text = "IS (PIQ): Not implemented in this version."
194
+ fid_text_batch_info = "FID (PIQ): Use the 'Calculate FID (Set vs Set)' tab for FID scores."
195
+
196
+ # Cleanup temp dir if created
197
+ # if os.path.exists(batch_temp_dir): shutil.rmtree(batch_temp_dir)
198
+
199
+ return df_results, status_message, is_text, fid_text_batch_info
200
+
201
+
202
+ def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
203
+ """Handles FID calculation between two sets of uploaded images."""
204
+ if not set1_file_list or not set2_file_list:
205
+ return "Please upload files for both Set 1 and Set 2."
206
+
207
+ # Create temporary directories for Set 1 and Set 2
208
+ # Suffix helps identify user folders if many users hit it, though Gradio handles sessions.
209
+ # Prefix is better for mkdtemp.
210
+ set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
211
+ set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
212
+
213
+ fid_result_text = "Starting FID calculation..."
214
+ progress(0.1, desc="Preparing image sets for FID...")
215
+
216
+ try:
217
+ # Copy uploaded files to these temporary directories
218
+ for i, file_obj in enumerate(set1_file_list):
219
+ shutil.copy(file_obj.name, os.path.join(set1_dir, os.path.basename(file_obj.name)))
220
+ progress(0.1 + 0.2 * (i / len(set1_file_list)), desc=f"Copying Set 1: {os.path.basename(file_obj.name)}")
221
+
222
+ for i, file_obj in enumerate(set2_file_list):
223
+ shutil.copy(file_obj.name, os.path.join(set2_dir, os.path.basename(file_obj.name)))
224
+ progress(0.3 + 0.2 * (i / len(set2_file_list)), desc=f"Copying Set 2: {os.path.basename(file_obj.name)}")
225
+
226
+ num_set1 = len(os.listdir(set1_dir))
227
+ num_set2 = len(os.listdir(set2_dir))
228
+
229
+ if num_set1 == 0 or num_set2 == 0:
230
+ return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."
231
+
232
+ progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
233
+ fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
234
+ progress(1, desc="FID calculation complete.")
235
+ fid_result_text = f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
236
+
237
+ except Exception as e:
238
+ fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
239
+ finally:
240
+ # Cleanup temporary directories
241
+ if os.path.exists(set1_dir): shutil.rmtree(set1_dir)
242
+ if os.path.exists(set2_dir): shutil.rmtree(set2_dir)
243
+
244
+ return fid_result_text
245
+
246
+ # --- Gradio UI Definition ---
247
+ css_custom = """
248
+ table {font-size: 0.8em !important; width: 100% !important;}
249
+ th, td {padding: 4px !important; text-align: left !important;}
250
+ img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;}
251
+ """
252
+ with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
253
+ gr.Markdown(f"""
254
+ # Image Generation Model Evaluation Tool
255
+ **Objective:** Automated evaluation and comparison of image quality from different model versions.
256
+ Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
257
+ (↓) means lower is better, (↑) means higher is better.
258
+ """)
259
+
260
+ with gr.Tabs():
261
+ with gr.TabItem("Per-Image Quality Evaluation"):
262
+ gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores. Images are processed in the browser's session.")
263
+
264
+ # Using gr.Files which allows multiple uploads and returns a list of TemporaryFileWrapper objects
265
+ image_upload_input = gr.Files(
266
+ label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
267
+ file_count="multiple",
268
+ type="filepath" # Provides path to temp file
269
+ )
270
+
271
+ evaluate_button_main = gr.Button("🖼️ Evaluate Uploaded Images", variant="primary")
272
+
273
+ gr.Markdown("---")
274
+ status_output_main = gr.Textbox(label="📊 Evaluation Status", interactive=False, lines=2)
275
+ # batch_is_output = gr.Textbox(label="Overall Inception Score (IS) for Batch", interactive=False, lines=1) # IS deferred
276
+ # batch_fid_output_info = gr.Textbox(label="Overall Fréchet Inception Distance (FID) for Batch", interactive=False, lines=1) # FID separate
277
+
278
+ gr.Markdown("### 🖼️ Per-Image Evaluation Results")
279
+ gr.Markdown("Click column headers to sort. Previews are thumbnails.")
280
+ results_table_output = gr.DataFrame(
281
+ headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
282
+ datatype=["html", "str", "number", "number", "number"],
283
+ interactive=False, wrap=True, max_rows=15, overflow_row_behaviour="paginate"
284
+ )
285
+
286
+ with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
287
+ gr.Markdown("""
288
+ Calculate Fréchet Inception Distance (FID) between two sets of images.
289
+ FID measures the similarity of two image distributions (e.g., generated vs. real, or version A vs. version B).
290
+ **Lower FID scores are better**, indicating more similarity.
291
+ """)
292
+ with gr.Row():
293
+ fid_set1_upload = gr.Files(label="Upload Images for Set 1 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
294
+ fid_set2_upload = gr.Files(label="Upload Images for Set 2 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
295
+
296
+ fid_calculate_button = gr.Button("🔗 Calculate FID between Set 1 and Set 2", variant="primary")
297
+ fid_result_output = gr.Textbox(label="📈 FID Result", interactive=False, lines=2)
298
+
299
+ # Wire components
300
+ evaluate_button_main.click(
301
+ fn=process_images_for_individual_scores,
302
+ inputs=[image_upload_input],
303
+ outputs=[results_table_output, status_output_main] #, batch_is_output, batch_fid_output_info]
304
+ )
305
+
306
+ fid_calculate_button.click(
307
+ fn=process_fid_for_two_sets,
308
+ inputs=[fid_set1_upload, fid_set2_upload],
309
+ outputs=[fid_result_output]
310
+ )
311
+
312
+ # --- For Hugging Face Spaces: requirements.txt ---
313
+ # Ensure this content is in your 'requirements.txt' file in the HF Space:
314
+ """
315
+ gradio
316
+ torch
317
+ torchvision
318
+ Pillow
319
+ numpy
320
+ piq>=0.8.0 # Specify version if known good, or just piq
321
+ iqa-pytorch>=0.2.1 # Specify version if known good
322
+ timm # A dependency for some iqa-pytorch models like MUSIQ
323
+ scikit-image # Often a transitive dependency, good to include
324
+ pandas
325
+ """
326
+
327
+ if __name__ == "__main__":
328
+ if piq is None or IQA is None:
329
+ print("\n\nWARNING: One or more core libraries (PIQ, IQA-PyTorch) are missing.")
330
+ print("Please install them by creating a 'requirements.txt' file with the content above and running: pip install -r requirements.txt\n\n")
331
+
332
+ demo.launch(debug=True) # Set debug=False for production