File size: 13,594 Bytes
8c2a1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc8307
8c2a1e0
bdc8307
 
 
 
 
8c2a1e0
 
bdc8307
8c2a1e0
 
 
 
 
bdc8307
 
 
 
 
 
 
 
 
 
8c2a1e0
 
 
 
 
 
 
5157b8f
8c2a1e0
 
 
 
 
bdc8307
8c2a1e0
 
 
 
bdc8307
8c2a1e0
 
bdc8307
8c2a1e0
 
 
 
bdc8307
8c2a1e0
 
bdc8307
8c2a1e0
 
 
 
bdc8307
 
 
 
8c2a1e0
 
 
bdc8307
 
5157b8f
8c2a1e0
 
 
 
 
bdc8307
8c2a1e0
5157b8f
8c2a1e0
 
 
 
 
 
 
 
 
 
bdc8307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c2a1e0
bdc8307
8c2a1e0
 
5157b8f
8c2a1e0
 
 
 
 
 
 
 
 
 
 
bdc8307
8c2a1e0
 
 
 
 
 
 
 
 
bdc8307
 
 
8c2a1e0
 
 
 
 
 
 
 
 
bdc8307
8c2a1e0
 
 
 
bdc8307
8c2a1e0
 
bdc8307
8c2a1e0
 
 
 
 
5157b8f
8c2a1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc8307
 
8c2a1e0
 
 
 
bdc8307
8c2a1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc8307
8c2a1e0
 
 
 
bdc8307
 
8c2a1e0
 
 
 
 
 
bdc8307
 
5157b8f
 
bdc8307
8c2a1e0
 
 
 
 
bdc8307
8c2a1e0
 
bdc8307
 
8c2a1e0
 
 
bdc8307
 
8c2a1e0
bdc8307
 
8c2a1e0
 
 
 
 
 
5157b8f
bdc8307
5157b8f
 
8c2a1e0
bdc8307
8c2a1e0
 
 
bdc8307
 
5157b8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import base64
import functools
import html
import io
import math
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image

# PIQ imports
try:
    import piq
except ImportError:
    print("Warning: PIQ library not found. Some metrics (BRISQUE, FID) will be unavailable.")
    piq = None

# IQA-PyTorch imports
try:
    # This import needs to succeed for NIQE and MUSIQ
    from iqa_pytorch import IQA
except ImportError as e:
    print(f"ERROR: IQA-PyTorch library import failed: {e}. Some metrics (NIQE, MUSIQ-NR) will be unavailable. Check installation and dependencies (like kornia).")
    IQA = None
except Exception as e:
    print(f"ERROR: An unexpected error occurred during IQA-PyTorch import: {e}")
    IQA = None


# --- Configuration ---
# Metrics run on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
MAX_IMAGES_PER_BATCH = 100  # upload cap for the per-image tab
THUMBNAIL_SIZE = (64, 64)  # (width, height) bounding box for table previews

# --- Metric Normalization Parameters (Approximate typical ranges) ---
# Each raw metric is clipped to [min, max] and mapped onto [0, 1]:
#   lower_is_better=True : (max - value) / (max - min)
#   lower_is_better=False: (value - min) / (max - min)
# The bounds are heuristics rather than properties of the metrics; tune as needed.
METRIC_RANGES = {
    "brisque": dict(min=0, max=120, lower_is_better=True),  # typical BRISQUE range
    "niqe": dict(min=0, max=12, lower_is_better=True),  # typical NIQE range
    "musiq_nr": dict(min=10, max=90, lower_is_better=False),  # example MUSIQ range
}

# --- Metric Functions ---

def get_brisque_score(img_tensor_chw_01):
    """Return the PIQ BRISQUE score of one image (lower is better), rounded to 3 decimals.

    Args:
        img_tensor_chw_01: image tensor with values in [0, 1] (``data_range=1.``
            is passed to PIQ), either CHW or already batched as BCHW. A
            single-channel input is repeated to 3 channels. ``.item()`` below
            requires the result to be a single value, i.e. a batch of one.

    Returns:
        The score as a float, ``"N/A (PIQ missing)"`` when PIQ could not be
        imported, or ``"Error"`` when scoring failed (the exception is printed).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        if img_tensor_chw_01.ndim == 3:
            img_tensor_bchw_01 = img_tensor_chw_01.unsqueeze(0)
        else:
            img_tensor_bchw_01 = img_tensor_chw_01
        if img_tensor_bchw_01.shape[1] == 1:
            img_tensor_bchw_01 = img_tensor_bchw_01.repeat(1, 3, 1, 1)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            brisque_loss = piq.brisque(img_tensor_bchw_01.to(DEVICE), data_range=1.)
        return round(brisque_loss.item(), 3)
    except Exception as e:
        # Keep the per-image table usable, but do not hide the cause.
        print(f"BRISQUE calculation error: {e}")
        return "Error"

@functools.lru_cache(maxsize=1)
def _load_niqe_metric():
    """Build the IQA-PyTorch NIQE metric once; later calls reuse the same instance."""
    return IQA(libs='NIQE-PyTorch', device=DEVICE)


def get_niqe_score(img_pil_rgb):
    """Return the IQA-PyTorch NIQE score of a PIL image (lower is better), rounded to 3 decimals.

    The metric object is constructed on first use and cached, instead of being
    rebuilt for every image.

    Returns:
        The score as a float, ``"N/A (IQA missing)"`` when IQA-PyTorch could not
        be imported, or ``"Error"`` when scoring failed (the exception is printed).
    """
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        niqe_metric = _load_niqe_metric()
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            score = niqe_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        print(f"NIQE calculation error: {e}")
        return "Error"

@functools.lru_cache(maxsize=1)
def _load_musiq_nr_metric():
    """Build the IQA-PyTorch MUSIQ metric once; later calls reuse the same instance."""
    # Example variant; other MUSIQ variants could be selected here.
    return IQA(libs='MUSIQ-L2N-lessons', device=DEVICE)


def get_musiq_nr_score(img_pil_rgb):
    """Return the IQA-PyTorch MUSIQ score of a PIL image (higher is better), rounded to 3 decimals.

    The metric object (a neural model) is constructed on first use and cached,
    instead of being rebuilt for every image.

    Returns:
        The score as a float, ``"N/A (IQA missing)"`` when IQA-PyTorch could not
        be imported, or ``"Error"`` when scoring failed (the exception is printed).
    """
    if IQA is None:
        return "N/A (IQA missing)"
    try:
        musiq_metric = _load_musiq_nr_metric()
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            score = musiq_metric(img_pil_rgb)
        return round(score.item(), 3)
    except Exception as e:
        print(f"MUSIQ-NR calculation error: {e}")
        return "Error"

_FID_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
_FID_INPUT_SIZE = (299, 299)  # common size so mixed-resolution images can share a batch


class _FIDImageDataset(torch.utils.data.Dataset):
    """Lazily load image files as RGB float tensors for ``piq.FID.compute_feats``."""

    def __init__(self, file_paths, size=_FID_INPUT_SIZE):
        self.file_paths = list(file_paths)
        # ToTensor scales 8-bit RGB pixels to [0, 1]; Resize makes batch shapes equal.
        self.transform = T.Compose([T.Resize(size), T.ToTensor()])

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        with Image.open(self.file_paths[index]) as img:
            tensor = self.transform(img.convert("RGB"))
        # piq's compute_feats reads each batch's images from the "images" key.
        return {"images": tensor}


def _list_fid_images(folder):
    """Return the sorted paths of files in *folder* that have a supported image extension."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_FID_IMAGE_EXTENSIONS)
    )


def get_fid_score_piq_folders(path_to_set1_folder, path_to_set2_folder, batch_size=32):
    """Compute the PIQ FID between the images of two folders (lower is better).

    ``piq.FID.compute_feats`` consumes a DataLoader, not a list of paths, so each
    folder is wrapped in a lazily-loading dataset before feature extraction.

    Args:
        path_to_set1_folder: folder holding the first image set.
        path_to_set2_folder: folder holding the second image set.
        batch_size: images per feature-extraction batch.

    Returns:
        The FID rounded to 3 decimals, or a human-readable message string when
        PIQ is missing, a set has fewer than 2 images, or the computation fails
        (the exception is also printed).
    """
    if piq is None:
        return "N/A (PIQ missing)"
    try:
        set1_files = _list_fid_images(path_to_set1_folder)
        set2_files = _list_fid_images(path_to_set2_folder)
        if not set1_files or not set2_files:
            return "One or both sets have no valid image files."
        if len(set1_files) < 2 or len(set2_files) < 2:
            return f"FID needs at least 2 images per set. Found: Set1={len(set1_files)}, Set2={len(set2_files)}."

        set1_loader = torch.utils.data.DataLoader(_FIDImageDataset(set1_files), batch_size=batch_size, shuffle=False)
        set2_loader = torch.utils.data.DataLoader(_FIDImageDataset(set2_files), batch_size=batch_size, shuffle=False)

        fid_metric = piq.FID()
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            set1_features = fid_metric.compute_feats(set1_loader, device=DEVICE)
            set2_features = fid_metric.compute_feats(set2_loader, device=DEVICE)
            if set1_features is None or set2_features is None:
                return "Could not extract features for one or both sets."
            if set1_features.dim() == 0 or set2_features.dim() == 0 or set1_features.numel() == 0 or set2_features.numel() == 0:
                return "Feature extraction resulted in empty tensors."
            fid_value = fid_metric(set1_features, set2_features)
        return round(fid_value.item(), 3)
    except Exception as e:
        print(f"FID calculation error: {e}")
        return f"FID Error: {str(e)[:100]}"

# --- Helper & Final Score Calculation ---
def pil_to_tensor_chw_01(img_pil_rgb):
    """Convert a PIL image into a CHW float tensor using torchvision's ``ToTensor``.

    For 8-bit images such as the RGB images used in this file, values end up in [0, 1].
    """
    to_tensor = T.ToTensor()
    return to_tensor(img_pil_rgb)

def create_thumbnail_base64(img_pil_rgb, size=THUMBNAIL_SIZE):
    """Return a ``data:image/png;base64,...`` URI for a shrunken copy of *img_pil_rgb*.

    ``Image.thumbnail`` resizes in place, so a copy is shrunk to fit within
    *size* and the caller's image is left untouched.
    """
    thumb = img_pil_rgb.copy()
    thumb.thumbnail(size)
    with io.BytesIO() as png_buffer:
        thumb.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")

def calculate_final_score(brisque_val, niqe_val, musiq_nr_val, metric_ranges=None):
    """Combine the available per-image metrics into one heuristic 0-10 score (higher is better).

    Each numeric metric is clipped to its configured [min, max] range and
    min-max normalized to [0, 1], inverted when ``lower_is_better`` is set. The
    normalized values are averaged and scaled to 0-10.

    Args:
        brisque_val, niqe_val, musiq_nr_val: values returned by the get_*_score
            helpers. Non-numeric placeholders (e.g. "Error", "N/A ..."), booleans
            and NaN are ignored.
        metric_ranges: optional mapping with "brisque", "niqe" and "musiq_nr"
            entries shaped like METRIC_RANGES; defaults to METRIC_RANGES.

    Returns:
        The score rounded to 4 decimals, or "N/A" when no metric is usable.

    Raises:
        ValueError: if a range used for a numeric metric has max <= min.
    """
    if metric_ranges is None:
        metric_ranges = METRIC_RANGES

    normalized_scores = []
    for key, value in (("brisque", brisque_val), ("niqe", niqe_val), ("musiq_nr", musiq_nr_val)):
        # bool is an int subclass but never a real score; error strings are skipped too.
        if isinstance(value, bool) or not isinstance(value, (float, int)):
            continue
        # NaN would otherwise be clipped to `max` (min(max, nan) == max) and scored silently.
        if isinstance(value, float) and math.isnan(value):
            continue
        cfg = metric_ranges[key]
        span = cfg["max"] - cfg["min"]
        if span <= 0:
            raise ValueError(f"Invalid range for metric '{key}': max must be greater than min.")
        clipped = max(cfg["min"], min(cfg["max"], value))
        if cfg["lower_is_better"]:
            normalized_scores.append((cfg["max"] - clipped) / span)
        else:
            normalized_scores.append((clipped - cfg["min"]) / span)

    if not normalized_scores:
        return "N/A"

    # Average of normalized scores, then scale to 0-10
    final_score_0_10 = (sum(normalized_scores) / len(normalized_scores)) * 10.0
    return round(final_score_0_10, 4)


# --- Main Processing Functions for Gradio ---
def process_images_for_individual_scores(uploaded_file_list, progress=gr.Progress(track_tqdm=True)):
    """Score each uploaded image with BRISQUE, NIQE and MUSIQ-NR plus the combined score.

    Only the first MAX_IMAGES_PER_BATCH uploads are processed. A file that
    fails is recorded as an error row instead of aborting the batch.

    Args:
        uploaded_file_list: the value of the ``gr.Files`` upload component.
        progress: Gradio progress tracker.

    Returns:
        ``(DataFrame, status_message)``. The DataFrame has one row per processed
        file; its keys match the headers of the results table in the UI.
    """
    if not uploaded_file_list:
        return pd.DataFrame(), "Please upload images first."

    if len(uploaded_file_list) > MAX_IMAGES_PER_BATCH:
        status_message = f"Too many images ({len(uploaded_file_list)}). Processing first {MAX_IMAGES_PER_BATCH} images."
        uploaded_file_list = uploaded_file_list[:MAX_IMAGES_PER_BATCH]
    else:
        status_message = f"Processing {len(uploaded_file_list)} images..."

    progress(0, desc=status_message)
    results_data = []
    total = len(uploaded_file_list)

    for i, file_obj in enumerate(uploaded_file_list):
        base_filename = "Unknown File"
        try:
            # gr.Files(type="filepath") yields path strings; some Gradio versions
            # yield file-like wrappers exposing the path as `.name`. Accept both.
            file_path = getattr(file_obj, "name", file_obj)
            base_filename = os.path.basename(file_path)
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(file_path) as img:
                img_pil_rgb = img.convert("RGB")

            img_tensor_chw_01 = pil_to_tensor_chw_01(img_pil_rgb)
            brisque_val = get_brisque_score(img_tensor_chw_01)
            niqe_val = get_niqe_score(img_pil_rgb)
            musiq_nr_val = get_musiq_nr_score(img_pil_rgb)

            final_score = calculate_final_score(brisque_val, niqe_val, musiq_nr_val)

            thumbnail_b64 = create_thumbnail_base64(img_pil_rgb)
            # Filenames are user-controlled; escape before embedding them in HTML.
            preview_html = f'<img src="{thumbnail_b64}" alt="{html.escape(base_filename, quote=True)}">'

            results_data.append({
                "Preview": preview_html,
                "Filename": base_filename,
                "BRISQUE (PIQ) (↓)": brisque_val,
                "NIQE (IQA-PyTorch) (↓)": niqe_val,
                "MUSIQ-NR (IQA-PyTorch) (↑)": musiq_nr_val,
                "Final Score (0-10) (↑)": final_score,
            })
        except Exception as e:
            print(f"Error processing {base_filename}: {e}")
            results_data.append({
                "Preview": "Error processing", "Filename": base_filename,
                "BRISQUE (PIQ) (↓)": f"Load Err: {str(e)[:30]}",
                "NIQE (IQA-PyTorch) (↓)": "N/A",
                "MUSIQ-NR (IQA-PyTorch) (↑)": "N/A",
                "Final Score (0-10) (↑)": "N/A",
            })
        progress((i + 1) / total, desc=f"Processing {base_filename}")

    df_results = pd.DataFrame(results_data)
    status_message += f"\nPer-image metrics calculated for {len(results_data)} images."
    return df_results, status_message

def _stage_fid_uploads(file_list, dest_dir, progress, start, span, label):
    """Copy uploads into *dest_dir*, prefixing each name with its index so equal basenames don't collide."""
    for i, file_obj in enumerate(file_list):
        # gr.Files(type="filepath") yields path strings; some Gradio versions
        # yield file-like wrappers exposing the path as `.name`. Accept both.
        src_path = getattr(file_obj, "name", file_obj)
        base_name = os.path.basename(src_path)
        shutil.copy(src_path, os.path.join(dest_dir, f"{i:05d}_{base_name}"))
        progress(start + span * (i / len(file_list)), desc=f"Copying {label}: {base_name}")


def process_fid_for_two_sets(set1_file_list, set2_file_list, progress=gr.Progress(track_tqdm=True)):
    """Stage two uploaded image sets in temporary folders and return a FID result message.

    Args:
        set1_file_list: uploads for Set 1 (value of a ``gr.Files`` component).
        set2_file_list: uploads for Set 2.
        progress: Gradio progress tracker.

    Returns:
        A human-readable string with the FID value or an error description.
        The temporary folders are always removed before returning.
    """
    if not set1_file_list or not set2_file_list:
        return "Please upload files for both Set 1 and Set 2."
    temp_dirs = []
    progress(0.1, desc="Preparing image sets for FID...")
    try:
        set1_dir = tempfile.mkdtemp(prefix="fid_set1_")
        temp_dirs.append(set1_dir)
        set2_dir = tempfile.mkdtemp(prefix="fid_set2_")
        temp_dirs.append(set2_dir)

        _stage_fid_uploads(set1_file_list, set1_dir, progress, 0.1, 0.2, "Set 1")
        _stage_fid_uploads(set2_file_list, set2_dir, progress, 0.3, 0.2, "Set 2")

        num_set1 = len(os.listdir(set1_dir))
        num_set2 = len(os.listdir(set2_dir))
        if num_set1 == 0 or num_set2 == 0:
            return f"FID Error: One or both sets are empty after copying. Set 1: {num_set1}, Set 2: {num_set2}."
        progress(0.5, desc=f"Calculating FID between Set 1 ({num_set1} images) and Set 2 ({num_set2} images)...")
        fid_score = get_fid_score_piq_folders(set1_dir, set2_dir)
        progress(1, desc="FID calculation complete.")
        fid_result_text = f"FID (PIQ) between Set 1 ({num_set1} images) and Set 2 ({num_set2} images): {fid_score}"
    except Exception as e:
        fid_result_text = f"Error during FID preparation or calculation: {str(e)}"
    finally:
        # Runs on normal return, early return and error alike.
        for temp_dir in temp_dirs:
            shutil.rmtree(temp_dir, ignore_errors=True)
    return fid_result_text

# --- Gradio UI Definition ---
# Compact table styling; the img rule caps the inline thumbnail previews at 64px.
css_custom = """
table {font-size: 0.8em !important; width: 100% !important;} 
th, td {padding: 4px !important; text-align: left !important;}
img {max-width: 64px !important; max-height: 64px !important; object-fit: contain;}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css_custom) as demo:
    gr.Markdown(f"""
    # Image Generation Model Evaluation Tool
    **Objective:** Automated evaluation and comparison of image quality from different model versions.
    Utilizes `PIQ` and `IQA-PyTorch` libraries. Runs on **{DEVICE}**.
    (↓) means lower is better, (↑) means higher is better.
    Final Score is a heuristic combination of available metrics (0-10, higher is better).
    """)

    with gr.Tabs():
        with gr.TabItem("Per-Image Quality Evaluation"):
            gr.Markdown(f"Upload a batch of images (up to **{MAX_IMAGES_PER_BATCH}**) to get individual quality scores.")
            image_upload_input = gr.Files(label=f"Upload Images (max {MAX_IMAGES_PER_BATCH}, .png, .jpg, .jpeg, .bmp, .webp)", file_count="multiple", type="filepath")
            evaluate_button_main = gr.Button("🖼️ Evaluate Uploaded Images", variant="primary")
            gr.Markdown("---")
            status_output_main = gr.Textbox(label="📊 Evaluation Status", interactive=False, lines=2)
            gr.Markdown("### 🖼️ Per-Image Evaluation Results")
            gr.Markdown("Click column headers to sort. Previews are thumbnails.")
            # Headers must match the row keys built in process_images_for_individual_scores.
            results_table_output = gr.DataFrame(
                headers=["Preview", "Filename", "BRISQUE (PIQ) (↓)", "NIQE (IQA-PyTorch) (↓)", "MUSIQ-NR (IQA-PyTorch) (↑)", "Final Score (0-10) (↑)"],
                datatype=["html", "str", "number", "number", "number", "number"],  # "html" renders the <img> preview
                interactive=False,
                wrap=True,
                row_count=(15, "paginate")
            )

        with gr.TabItem("↔️ Calculate FID (Set vs. Set)"):
            gr.Markdown("""
            Calculate Fréchet Inception Distance (FID) between two sets of images. 
            FID measures the similarity of two image distributions. **Lower FID scores are better**.
            """)
            with gr.Row():
                fid_set1_upload = gr.Files(label="Upload Images for Set 1", file_count="multiple", type="filepath")
                fid_set2_upload = gr.Files(label="Upload Images for Set 2", file_count="multiple", type="filepath")
            fid_calculate_button = gr.Button("🔗 Calculate FID between Set 1 and Set 2", variant="primary")
            fid_result_output = gr.Textbox(label="📈 FID Result", interactive=False, lines=2)

    # Wire buttons to the processing functions defined above.
    evaluate_button_main.click(fn=process_images_for_individual_scores, inputs=[image_upload_input], outputs=[results_table_output, status_output_main])
    fid_calculate_button.click(fn=process_fid_for_two_sets, inputs=[fid_set1_upload, fid_set2_upload], outputs=[fid_result_output])

# --- For Hugging Face Spaces ---
# Ensure 'requirements.txt' includes:
# (The triple-quoted string below is a bare expression statement: it is never
#  assigned or read at runtime and only serves as a copy-paste dependency list.)
# NOTE(review): confirm that `iqa-pytorch==0.1` is the distribution providing
# the `iqa_pytorch.IQA` import used at the top of this file.
"""
gradio
torch
torchvision
Pillow
numpy
piq>=0.8.0
iqa-pytorch==0.1
timm
scikit-image
pandas
kornia 
"""

if __name__ == "__main__":
    # Warn loudly about missing optional metric libraries before serving the UI.
    if piq is None:
        print("\nWARNING: PIQ library is missing. pip install piq\n")
    if IQA is None:
        print("\nERROR: IQA-PyTorch library import failed. pip install iqa-pytorch==0.1 kornia\n")
    demo.launch(debug=True)