File size: 14,977 Bytes
8c2a1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a98506
8c2a1e0
 
 
 
 
 
 
 
 
6a98506
8c2a1e0
 
 
6a98506
 
 
 
8c2a1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import base64
import html
import io
import logging
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image

# PIQ imports
try:
    import piq
except ImportError:
    print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.")
    piq = None

# IQA-PyTorch imports
try:
    from iqa_pytorch import IQA
    # Available models in IQA-PyTorch (examples for NR):
    # "MUSIQ-L2N-lessons", "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR"
    # "BRISQUE-PyTorch", "NIQE-PyTorch"
    # "NIMA-VGG16-estimate", "NIMA-MobileNet-estimate" (Aesthetic)
except ImportError:
    print("Warning: IQA-PyTorch library not found. Some metrics (NIQE, MUSIQ-NR) will be unavailable.")
    IQA = None

# --- Configuration ---
# Use the GPU when PyTorch reports one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Upper bound on how many uploaded images are scored per click in the per-image tab.
MAX_IMAGES_PER_BATCH = 100

# Bounding box, as (width, height) in pixels, for the preview thumbnails.
THUMBNAIL_SIZE = (64, 64)

# --- Metric Functions ---

def get_brisque_score(img_tensor_chw_01):
    """Calculate the BRISQUE score of one image using PIQ (lower is better).

    Args:
        img_tensor_chw_01: Image tensor with values in [0, 1]. It may be shaped
            (C, H, W), already batched as (B, C, H, W), or a 2-D (H, W)
            single-channel image.

    Returns:
        The score rounded to 3 decimals, ``"N/A (PIQ missing)"`` when PIQ could
        not be imported, or ``"Error"`` when the computation raises (the
        traceback is logged).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        img = img_tensor_chw_01
        # Bring the input up to 4-D (B, C, H, W), which piq.brisque expects.
        if img.ndim == 2:  # (H, W) -> (1, H, W)
            img = img.unsqueeze(0)
        if img.ndim == 3:  # (C, H, W) -> (1, C, H, W)
            img = img.unsqueeze(0)

        # Replicate a single channel into a 3-channel (RGB-layout) tensor.
        if img.shape[1] == 1:
            img = img.repeat(1, 3, 1, 1)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            brisque_value = piq.brisque(img.to(DEVICE), data_range=1.)
        return round(brisque_value.item(), 3)
    except Exception:
        # Best-effort per-image metric: the table shows "Error", but the cause
        # is kept in the server log instead of being discarded.
        logging.getLogger(__name__).exception(
            "BRISQUE failed for input of shape %s",
            tuple(getattr(img_tensor_chw_01, "shape", ())),
        )
        return "Error"


# NIQE metric instance, created on first use and then reused, so the model is
# built once per process instead of once per scored image.
_NIQE_METRIC = None


def get_niqe_score(img_pil_rgb):
    """Calculate the no-reference NIQE score of one image using IQA-PyTorch (lower is better).

    Args:
        img_pil_rgb: PIL image, expected to already be in RGB mode.

    Returns:
        The score rounded to 3 decimals, ``"N/A (IQA missing)"`` when IQA-PyTorch
        could not be imported, or ``"Error"`` when building the metric or scoring
        raises (the traceback is logged).
    """
    # NOTE(review): `iqa_pytorch.IQA(libs=...)` is used as-is from the original
    # code; confirm it matches the installed IQA-PyTorch API.
    global _NIQE_METRIC
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        if _NIQE_METRIC is None:
            # Assumes the metric object keeps no per-image state between calls
            # (plain inference) - confirm if the backend changes.
            _NIQE_METRIC = IQA(libs='NIQE-PyTorch', device=DEVICE)  # NIQE is No-Reference
        score = _NIQE_METRIC(img_pil_rgb)
        return round(score.item(), 3)
    except Exception:
        logging.getLogger(__name__).exception("NIQE scoring failed")
        return "Error"

# No-reference MUSIQ metric instance, created on first use and then reused, so
# the model is built once per process instead of once per scored image.
_MUSIQ_NR_METRIC = None


def get_musiq_nr_score(img_pil_rgb):
    """Calculate the no-reference MUSIQ score of one image using IQA-PyTorch (higher is better).

    Args:
        img_pil_rgb: PIL image, expected to already be in RGB mode.

    Returns:
        The score rounded to 3 decimals, ``"N/A (IQA missing)"`` when IQA-PyTorch
        could not be imported, or ``"Error"`` when building the metric or scoring
        raises (the traceback is logged).
    """
    # NOTE(review): the model id 'MUSIQ-L2N-lessons' and the `IQA(libs=...)`
    # constructor are taken from the original code; confirm both against the
    # installed IQA-PyTorch version. Alternatives noted by the original author:
    # "MUSIQ-Koniq-NSR", "MUSIQ-SpAq-NSR".
    global _MUSIQ_NR_METRIC
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        if _MUSIQ_NR_METRIC is None:
            # Assumes the metric object keeps no per-image state between calls
            # (plain inference) - confirm if the backend changes.
            _MUSIQ_NR_METRIC = IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)
        score = _MUSIQ_NR_METRIC(img_pil_rgb)
        return round(score.item(), 3)
    except Exception:
        logging.getLogger(__name__).exception("MUSIQ-NR scoring failed")
        return "Error"


# File extensions (lower-case) treated as images when collecting an FID set.
_FID_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
# Every image is resized to this (width, height) so images of different sizes
# can be stacked into one batch (299x299 is the Inception-V3 input size).
_FID_INPUT_SIZE = (299, 299)
# Number of images per feature-extraction batch.
_FID_BATCH_SIZE = 16


def _list_fid_images(folder):
    """Return the sorted paths of the files in *folder* that have an image extension."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_FID_IMAGE_EXTENSIONS)
    )


class _FIDImageDataset(torch.utils.data.Dataset):
    """Map-style dataset returning ``{"images": tensor}`` for each image path.

    Each tensor is (3, 299, 299) with values in [0, 1] (RGB conversion, resize,
    then ``ToTensor``). The dict-with-"images" layout is what PIQ's
    ``compute_feats`` reads from every batch.
    """

    def __init__(self, paths):
        self.paths = list(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Close the file right away; convert()/resize() produce in-memory copies.
        with Image.open(self.paths[index]) as img:
            rgb = img.convert("RGB").resize(_FID_INPUT_SIZE)
        return {"images": T.ToTensor()(rgb)}


def _compute_fid_feats(fid_metric, paths):
    """Extract feature vectors for the images at *paths* with PIQ's default extractor."""
    loader = torch.utils.data.DataLoader(
        _FIDImageDataset(paths), batch_size=_FID_BATCH_SIZE, shuffle=False
    )
    return fid_metric.compute_feats(loader, device=DEVICE)


def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder):
    """Calculate FID between the images in two folders using PIQ (lower is better).

    Only files ending in .png/.jpg/.jpeg/.bmp/.webp (case-insensitive) are used.
    Each image is converted to RGB and resized to 299x299 before feature extraction.

    Args:
        path_to_set1_folder: Folder holding the first image set.
        path_to_set2_folder: Folder holding the second image set.

    Returns:
        The FID rounded to 3 decimals, or a human-readable message string when
        PIQ is missing, a set has fewer than two images, feature extraction
        produced nothing, or any step raised (the error is also logged).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        set1_files = _list_fid_images(path_to_set1_folder)
        set2_files = _list_fid_images(path_to_set2_folder)

        if not set1_files or not set2_files:
            return "One or both sets have no valid image files."
        # Mean/covariance statistics need at least two samples per set.
        if len(set1_files) < 2 or len(set2_files) < 2:
            return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."

        fid_metric = piq.FID()
        # compute_feats consumes a loader of {"images": batch} dicts; passing a
        # plain list of file paths (as before) cannot work.
        set1_features = _compute_fid_feats(fid_metric, set1_files)
        set2_features = _compute_fid_feats(fid_metric, set2_files)

        if set1_features is None or set2_features is None:
            return "Could not extract features for one or both sets (check image validity and count)."
        if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0:
            return "Feature extraction resulted in empty tensors."

        fid_value = fid_metric(set1_features, set2_features)  # features passed directly
        return round(fid_value.item(), 3)
    except Exception as e:
        logging.getLogger(__name__).exception("FID calculation error: %s", e)
        return f"FID Error: {str(e)[:100]}"

# --- Helper Functions ---
def pil_to_tensor_chw_01(img_pil_rgb):
    """Return *img_pil_rgb* as a float tensor laid out (C, H, W) with values scaled into [0, 1]."""
    to_tensor = T.ToTensor()  # 8-bit [0, 255] pixels -> float [0, 1], channels first
    return to_tensor(img_pil_rgb)

def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
    """Return a ``data:image/png;base64,...`` URI holding a shrunken copy of *img_pil_rgb*.

    ``size`` is the (width, height) box passed to PIL's ``thumbnail``. The
    caller's image is not modified.
    """
    preview = img_pil_rgb.copy()  # thumbnail() works in place, so shrink a copy
    preview.thumbnail(size)

    with io.BytesIO() as png_buffer:
        preview.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    return "data:image/png;base64," + encoded

# --- Main Processing Functions for Gradio ---

def _score_image_file(file_path):
    """Open one image file and return its results-table row (preview HTML, filename, scores)."""
    base_filename = os.path.basename(file_path)

    # Release the file handle immediately; convert() yields an in-memory image.
    with Image.open(file_path) as img:
        img_pil_rgb = img.convert("RGB")

    # BRISQUE (PIQ) takes a (C, H, W) tensor; NIQE and MUSIQ (IQA-PyTorch) take the PIL image.
    brisque_val = get_brisque_score(pil_to_tensor_chw_01(img_pil_rgb))
    niqe_val = get_niqe_score(img_pil_rgb)
    musiq_nr_val = get_musiq_nr_score(img_pil_rgb)

    thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
    # The filename is chosen by the uploader, so escape it before embedding it
    # in the HTML preview cell.
    preview_html = f'<img src="{thumbnail_b64}" alt="{html.escape(base_filename)}">'

    return {
        "Preview": preview_html,
        "Filename": base_filename,
        "BRISQUE (PIQ) (↓)": brisque_val,
        "NIQE (IQA-PyTorch) (↓)": niqe_val,
        "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
    }


def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
    """Score each uploaded image with BRISQUE, NIQE and MUSIQ-NR for the results table.

    At most ``MAX_IMAGES_PER_BATCH`` uploads are processed; the rest are dropped
    and the status message says so. An image that cannot be processed gets a
    placeholder row instead of aborting the batch (the error is logged).

    Args:
        uploaded_file_list: Value of a ``gr.Files`` component. Entries are paths,
            given either as strings or as objects exposing the path as ``.name``.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple: results DataFrame, status message, IS note, FID note.
    """
    if not uploaded_file_list:
        return pd.DataFrame(), "Please upload images first.", "IS: N/A (Not Implemented)", "FID: N/A (Use FID Tab)"

    if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
        status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
        uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
    else:
        status_message = f"Processing {len(uploaded_file_list)} images..."

    progress(0, desc=status_message)

    total = len(uploaded_file_list)
    results_data = []
    for i, file_obj in enumerate(uploaded_file_list):
        file_path = getattr(file_obj, "name", file_obj)
        try:
            row = _score_image_file(file_path)
        except Exception:
            # Per-image boundary: record a placeholder row and keep going.
            logging.getLogger(__name__).exception("Could not process upload %r", file_path)
            row = {
                "Preview": "Error processing",
                "Filename": os.path.basename(str(file_path)),
                "BRISQUE (PIQ) (↓)": "Load Err",
                "NIQE (IQA-PyTorch) (↓)": "N/A",
                "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
            }
        results_data.append(row)
        progress((i + 1) / total, desc=f"Processing {row['Filename']}")

    df_results = pd.DataFrame(results_data)
    status_message += f"\nPer-image metrics calculated for {len(results_data)} images."

    # Batch-level metrics: IS is not implemented; FID lives in its own tab.
    is_text = "IS (PIQ): Not implemented in this version."
    fid_text_batch_info = "FID (PIQ): Use the 'Calculate FID (Set vs Set)' tab for FID scores."

    return df_results, status_message, is_text, fid_text_batch_info


def _copy_uploads_for_fid(file_list, dest_dir, progress, progress_start, set_label):
    """Copy uploaded files into *dest_dir*, advancing progress from *progress_start* by up to 0.2.

    Each copy is prefixed with its upload index, so two uploads that share a
    basename (e.g. ``image.png`` picked from different folders) cannot overwrite
    each other and silently shrink the set. The original extension is kept.
    """
    total = len(file_list)
    for i, file_obj in enumerate(file_list):
        # Uploads arrive as path strings or as objects exposing the path as `.name`.
        src_path = getattr(file_obj, "name", file_obj)
        base_name = os.path.basename(src_path)
        shutil.copy(src_path, os.path.join(dest_dir, f"{i:05d}_{base_name}"))
        progress(progress_start + 0.2 * (i / total), desc=f"Copying {set_label}: {base_name}")


def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
    """Compute FID between two uploaded image sets and return a message for the UI.

    Both sets are copied into one temporary working directory, which is deleted
    on every exit path, and scored with ``get_fid_score_piq_folders``.

    Args:
        set1_file_list: Value of the first ``gr.Files`` component.
        set2_file_list: Value of the second ``gr.Files`` component.
        progress: Gradio progress tracker.

    Returns:
        A human-readable FID result or error message.
    """
    if not set1_file_list or not set2_file_list:
        return "Please upload files for both Set 1 and Set 2."

    with tempfile.TemporaryDirectory(prefix="fid_") as work_dir:
        progress(0.1, desc="Preparing image sets for FID...")
        try:
            set1_dir = os.path.join(work_dir, "set1")
            set2_dir = os.path.join(work_dir, "set2")
            os.mkdir(set1_dir)
            os.mkdir(set2_dir)

            _copy_uploads_for_fid(set1_file_list, set1_dir, progress, 0.1, "Set 1")
            _copy_uploads_for_fid(set2_file_list, set2_dir, progress, 0.3, "Set 2")

            num_set1 = len(os.listdir(set1_dir))
            num_set2 = len(os.listdir(set2_dir))
            if num_set1 == 0 or num_set2 == 0:
                return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."

            progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
            fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
            progress(1, desc="FID calculation complete.")
            return f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
        except Exception as e:
            return f"Error during FID preparation or calculation: {str(e)}"

# --- Gradio UI Definition ---
css_custom = """
table {font-size: 0.8em !important; width: 100% !important;} 
th, td {padding: 4px !important; text-align: left !important;}
img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
    gr.Markdown(f"""
    # Image Generation Model Evaluation Tool
    **Objective:** Automated evaluation and comparison of image quality from different model versions.
    Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
    (↓) means lower is better, (↑) means higher is better.
    """)

    with gr.Tabs():
        with gr.TabItem("Per-Image Quality Evaluation"):
            gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores. Images are processed on the server.")

            image_upload_input = gr.Files(
                label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)",
                file_count="multiple",
                type="filepath",
            )

            evaluate_button_main = gr.Button("🖼️ Evaluate Uploaded Images", variant="primary")

            gr.Markdown("---")
            status_output_main = gr.Textbox(label="📊 Evaluation Status", interactive=False, lines=2)
            # process_images_for_individual_scores returns four values; these two
            # boxes receive the batch-level IS and FID notes (previously dropped).
            with gr.Row():
                batch_is_output = gr.Textbox(label="IS (Batch)", interactive=False)
                batch_fid_output_info = gr.Textbox(label="FID (Batch)", interactive=False)

            gr.Markdown("### 🖼️ Per-Image Evaluation Results")
            gr.Markdown("Click column headers to sort. Previews are thumbnails.")
            # `overflow_row_behaviour` / `max_rows` are not accepted by Gradio 4's
            # Dataframe; pass `height=<pixels>` here to cap the visible area instead.
            results_table_output = gr.DataFrame(
                headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)"],
                datatype=["html", "str", "number", "number", "number"],
                interactive=False,
                wrap=True,
            )

        with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
            gr.Markdown("""
            Calculate Fréchet Inception Distance (FID) between two sets of images. 
            FID measures the similarity of two image distributions (e.g., generated vs. real, or version A vs. version B). 
            **Lower FID scores are better**, indicating more similarity.
            """)
            with gr.Row():
                fid_set1_upload = gr.Files(label="Upload Images for Set 1 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
                fid_set2_upload = gr.Files(label="Upload Images for Set 2 (.png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")

            fid_calculate_button = gr.Button("🔗 Calculate FID between Set 1 and Set 2", variant="primary")
            fid_result_output = gr.Textbox(label="📈 FID Result", interactive=False, lines=2)

    # Wire components. The output order matches the handlers' return tuples.
    evaluate_button_main.click(
        fn=process_images_for_individual_scores,
        inputs=[image_upload_input],
        outputs=[results_table_output, status_output_main, batch_is_output, batch_fid_output_info],
    )

    fid_calculate_button.click(
        fn=process_fid_for_two_sets,
        inputs=[fid_set1_upload, fid_set2_upload],
        outputs=[fid_result_output],
    )

# --- For Hugging Face Spaces: requirements.txt ---
# Ensure this content is in your 'requirements.txt' file in the HF Space.
# This module-level string is never assigned or read at runtime; it exists only
# as copy-paste documentation of the dependency list. The __main__ warning at
# the bottom of the file refers to it as "the content above".
# NOTE(review): the code imports `iqa_pytorch`; confirm that the `iqa-pytorch`
# distribution listed here really provides that module (the IQA-PyTorch project
# is commonly installed as `pyiqa`).
"""
gradio
torch
torchvision
Pillow
numpy
piq>=0.8.0 # Specify version if known good, or just piq
iqa-pytorch>=0.2.1 # Specify version if known good
timm # A dependency for some iqa-pytorch models like MUSIQ
scikit-image # Often a transitive dependency, good to include
pandas
"""

if __name__ == "__main__":
    # The metric functions fall back to "N/A" when a backend failed to import,
    # so the app still launches; just make the missing dependency obvious.
    backend_missing = any(backend is None for backend in (piq, IQA))
    if backend_missing:
        for warning_line in (
            "\n\nWARNING: One or more core libraries (PIQ, IQA-PyTorch) are missing.",
            "Please install them by creating a 'requirements.txt' file with the content above and running: pip install -r requirements.txt\n\n",
        ):
            print(warning_line)

    # Set debug=False for production.
    demo.launch(debug=True)