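"""Evaluate restoration/enhancement results with pyiqa image-quality metrics.

Takes a folder of predicted videos, frame folders, or images (and optionally a
matching ground-truth folder), computes the requested full-reference and/or
no-reference metrics per sample, and writes per-sample plus averaged scores to
a JSON file.

Example invocation (script and folder names are illustrative):
    python eval_metrics.py --pred results/ --gt data/gt --metrics psnr,ssim,lpips --test_y_channel
"""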
import os
import cv2
import json
import torch
import pyiqa
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

to_tensor = transforms.ToTensor()
video_exts = ['.mp4', '.avi', '.mov', '.mkv']
# Full-reference metrics (require GT); any other requested metric is treated as no-reference.
fr_metrics = ['psnr', 'ssim', 'lpips', 'dists']
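# Note: metric names are passed straight to pyiqa.create_metric(); the full set of
# supported names can be listed with pyiqa.list_models().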

def is_video_file(filename):
    return any(filename.lower().endswith(ext) for ext in video_exts)


def rgb_to_y(img):
    # BT.601 limited-range RGB -> Y conversion; expects NCHW tensors scaled to [0, 1].
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    y = 0.257 * r + 0.504 * g + 0.098 * b + 0.0625
    return y


def crop_border(img, crop):
    # Trim `crop` pixels from each spatial border of an NCHW batch.
    return img[:, :, crop:-crop, crop:-crop]

def read_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes to BGR; convert to RGB before building the tensor.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(to_tensor(Image.fromarray(rgb)))
    cap.release()
    if not frames:
        raise ValueError(f"No frames could be read from video: {video_path}")
    return torch.stack(frames)

def read_image_folder(folder_path):
    image_files = sorted([
        os.path.join(folder_path, f) for f in os.listdir(folder_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    ])
    if not image_files:
        raise ValueError(f"No images found in folder: {folder_path}")
    frames = [to_tensor(Image.open(p).convert("RGB")) for p in image_files]
    return torch.stack(frames)

def load_sequence(path):
    # Returns a (T, C, H, W) tensor from a frame folder, a video file, or a single image.
    if os.path.isdir(path):
        return read_image_folder(path)
    elif os.path.isfile(path):
        if is_video_file(path):
            return read_video_frames(path)
        elif path.lower().endswith(('.png', '.jpg', '.jpeg')):
            img = to_tensor(Image.open(path).convert("RGB"))
            return img.unsqueeze(0)
    raise ValueError(f"Unsupported input: {path}")

def crop_img_center(img, target_h, target_w):
    _, h, w = img.shape
    top = max((h - target_h) // 2, 0)
    left = max((w - target_w) // 2, 0)
    return img[:, top:top + target_h, left:left + target_w]


def crop_img_top_left(img, target_h, target_w):
    return img[:, :target_h, :target_w]

def match_resolution(gt_frames, pred_frames, is_center=False, name=None):
    # Align frame count first, then crop both sequences to the smaller spatial size.
    t = min(gt_frames.shape[0], pred_frames.shape[0])
    gt_frames = gt_frames[:t]
    pred_frames = pred_frames[:t]
    _, _, h_g, w_g = gt_frames.shape
    _, _, h_p, w_p = pred_frames.shape

    target_h = min(h_g, h_p)
    target_w = min(w_g, w_p)

    if (h_g != h_p or w_g != w_p) and name:
        mode = "center" if is_center else "top-left"
        print(f"[{name}] Resolution mismatch: GT is ({h_g}, {w_g}), Pred is ({h_p}, {w_p}). "
              f"Both were {mode}-cropped to ({target_h}, {target_w}).")

    if is_center:
        gt_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in gt_frames])
        pred_frames = torch.stack([crop_img_center(f, target_h, target_w) for f in pred_frames])
    else:
        gt_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in gt_frames])
        pred_frames = torch.stack([crop_img_top_left(f, target_h, target_w) for f in pred_frames])

    return gt_frames, pred_frames

def init_models(metrics, device):
    models = {}
    for name in metrics:
        try:
            models[name] = pyiqa.create_metric(name).to(device).eval()
        except Exception as e:
            print(f"Failed to initialize metric '{name}': {e}")
    return models

def compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel):
    if batch_mode:
        # Score the whole sequence in one forward pass per metric.
        pred_batch = pred_frames.to(device)
        gt_batch = gt_frames.to(device)

        results = {}
        for name, model in models.items():
            if name in fr_metrics:
                pred_eval = pred_batch
                gt_eval = gt_batch
                if crop > 0:
                    pred_eval = crop_border(pred_eval, crop)
                    gt_eval = crop_border(gt_eval, crop)
                if test_y_channel:
                    pred_eval = rgb_to_y(pred_eval)
                    gt_eval = rgb_to_y(gt_eval)
                values = model(pred_eval, gt_eval)
            else:
                values = model(pred_batch)
            results[name] = round(values.mean().item(), 4)
        return results
    else:
        # Score frame by frame and average, keeping GPU memory usage low for long sequences.
        results = {name: [] for name in models}
        for pred, gt in zip(pred_frames, gt_frames):
            pred = pred.unsqueeze(0).to(device)
            gt = gt.unsqueeze(0).to(device)

            for name, model in models.items():
                if name in fr_metrics:
                    pred_eval = pred
                    gt_eval = gt
                    if crop > 0:
                        pred_eval = crop_border(pred_eval, crop)
                        gt_eval = crop_border(gt_eval, crop)
                    if test_y_channel:
                        pred_eval = rgb_to_y(pred_eval)
                        gt_eval = rgb_to_y(gt_eval)
                    value = model(pred_eval, gt_eval).item()
                else:
                    value = model(pred).item()
                results[name].append(value)

        return {k: round(float(np.mean(v)), 4) for k, v in results.items()}
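
# Illustrative use (assumes `pred` and `gt` are (T, C, H, W) tensors in [0, 1]):
#   models = init_models(['psnr', 'ssim'], device)
#   scores = compute_metrics(pred, gt, models, device,
#                            batch_mode=True, crop=4, test_y_channel=True)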

def process(gt_root, pred_root, out_path, metrics, batch_mode, crop, test_y_channel, is_center):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    models = init_models(metrics, device)

    has_gt = bool(gt_root and os.path.exists(gt_root))

    # Predictions are always indexed; GT files only when a GT root was supplied.
    pred_files = {os.path.splitext(f)[0]: os.path.join(pred_root, f) for f in os.listdir(pred_root)}
    gt_files = {}
    if has_gt:
        gt_files = {os.path.splitext(f)[0]: os.path.join(gt_root, f) for f in os.listdir(gt_root)}

    pred_names = sorted(pred_files.keys())
    results = {}
    aggregate = {metric: [] for metric in metrics}

    for name in tqdm(pred_names, desc="Evaluating"):
        if has_gt and name not in gt_files:
            print(f"Skipping {name}: no matching GT file.")
            continue

        pred_path = pred_files[name]
        gt_path = gt_files[name] if has_gt else None

        try:
            pred_frames = load_sequence(pred_path)

            if has_gt:
                gt_frames = load_sequence(gt_path)
                gt_frames, pred_frames = match_resolution(gt_frames, pred_frames, is_center=is_center, name=name)
                scores = compute_metrics(pred_frames, gt_frames, models, device, batch_mode, crop, test_y_channel)
            else:
                # Without GT, only no-reference metrics can be evaluated.
                nr_models = {k: v for k, v in models.items() if k not in fr_metrics}
                if not nr_models:
                    print(f"Skipping {name}: GT is not provided and no NR-IQA metrics were requested.")
                    continue
                # compute_metrics ignores its gt argument for NR metrics, so pass the predictions twice.
                scores = compute_metrics(pred_frames, pred_frames, nr_models, device, batch_mode, crop, test_y_channel)

            results[name] = scores
            for k in scores:
                aggregate[k].append(scores[k])
        except Exception as e:
            print(f"Error processing {name}: {e}")

    print("\nPer-sample Results:")
    for name in sorted(results):
        print(f"{name}: " + ", ".join(f"{k}={v:.4f}" for k, v in results[name].items()))

    print("\nOverall Average Results:")
    count = len(results)
    if count > 0:
        # Skip metrics that never produced a score (e.g. FR metrics when GT is absent).
        overall_avg = {k: round(float(np.mean(v)), 4) for k, v in aggregate.items() if v}
        for k, v in overall_avg.items():
            print(f"{k.upper()}: {v:.4f}")
    else:
        overall_avg = {}
        print("No valid samples were processed.")

    print(f"\nProcessed {count} samples.")

    output = {
        "per_sample": results,
        "average": overall_avg,
        "count": count
    }

    os.makedirs(out_path, exist_ok=True)
    out_name = 'metrics_' + '_'.join(metrics) + '.json'
    out_path = os.path.join(out_path, out_name)

    with open(out_path, 'w') as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to: {out_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt', type=str, default='', help='Path to GT folder (optional; omit for NR-IQA only)')
    parser.add_argument('--pred', type=str, required=True, help='Path to predicted results folder')
    parser.add_argument('--out', type=str, default='', help='Directory in which to save the JSON output')
    parser.add_argument('--metrics', type=str, default='psnr,ssim,clipiqa',
                        help='Comma-separated list of metrics: psnr,ssim,clipiqa,lpips,...')
    parser.add_argument('--batch_mode', action='store_true', help='Score all frames in one batch instead of frame by frame')
    parser.add_argument('--crop', type=int, default=0, help='Border size to crop before computing PSNR/SSIM')
    parser.add_argument('--test_y_channel', action='store_true', help='Evaluate PSNR/SSIM on the Y channel')
    parser.add_argument('--is_center', action='store_true', help='Use center crop (instead of top-left) when resolutions mismatch')

    args = parser.parse_args()

    # Default the output directory to the prediction folder.
    out = args.out if args.out else args.pred
    metric_list = [m.strip().lower() for m in args.metrics.split(',')]
    process(args.gt, args.pred, out, metric_list, args.batch_mode, args.crop, args.test_y_channel, args.is_center)