"""
Visualization Utilities

This module provides comprehensive visualization tools for segmentation results,
attention maps, and experiment comparisons in few-shot and zero-shot learning.
"""

import os
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch


class SegmentationVisualizer:
    """Visualization tools for segmentation results."""

    def __init__(self, figsize: Tuple[int, int] = (15, 10)):
        self.figsize = figsize

        # Fixed RGB colors (values in [0, 1]) for known classes; unknown
        # class names fall back to white in create_combined_mask.
        self.class_colors = {
            'building': [1.0, 0.0, 0.0],
            'road': [0.0, 1.0, 0.0],
            'vegetation': [0.0, 0.0, 1.0],
            'water': [1.0, 1.0, 0.0],
            'shirt': [1.0, 0.5, 0.0],
            'pants': [0.5, 0.0, 1.0],
            'dress': [0.0, 1.0, 1.0],
            'shoes': [1.0, 0.0, 1.0],
            'robot': [0.5, 0.5, 0.5],
            'tool': [0.8, 0.4, 0.2],
            'safety': [0.2, 0.8, 0.2],
        }
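
    # Colors can be extended at runtime for new domains; unknown class names
    # fall back to white. Hypothetical example:
    #   viz = SegmentationVisualizer()
    #   viz.class_colors['helmet'] = [0.9, 0.9, 0.1]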

    def visualize_segmentation(
        self,
        image: torch.Tensor,
        predictions: Dict[str, torch.Tensor],
        ground_truth: Optional[Dict[str, torch.Tensor]] = None,
        title: str = "Segmentation Results"
    ) -> plt.Figure:
        """Visualize segmentation results with optional ground-truth comparison."""
        num_classes = len(predictions)
        has_gt = ground_truth is not None

        # One overview row plus one row per class; a third column is needed
        # only when ground truth is available.
        cols = 3 if has_gt else 2
        rows = num_classes + 1

        fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4))
        axes = np.atleast_2d(axes)

        # Convert the CHW tensor to HWC numpy and rescale to [0, 1] if needed.
        image_np = image.detach().permute(1, 2, 0).cpu().numpy()
        if image_np.min() < 0 or image_np.max() > 1:
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())

        axes[0, 0].imshow(image_np)
        axes[0, 0].set_title("Original Image")
        axes[0, 0].axis('off')

        combined_pred = self.create_combined_mask(predictions)
        axes[0, 1].imshow(image_np)
        axes[0, 1].imshow(combined_pred, alpha=0.6)
        axes[0, 1].set_title("Combined Predictions")
        axes[0, 1].axis('off')

        if has_gt:
            combined_gt = self.create_combined_mask(ground_truth)
            axes[0, 2].imshow(image_np)
            axes[0, 2].imshow(combined_gt, alpha=0.6)
            axes[0, 2].set_title("Ground Truth")
            axes[0, 2].axis('off')

        for i, (class_name, pred_mask) in enumerate(predictions.items()):
            row = i + 1  # row 0 holds the overview panels

            pred_np = pred_mask.detach().cpu().numpy()
            axes[row, 0].imshow(pred_np, cmap='gray')
            axes[row, 0].set_title(f"Prediction: {class_name}")
            axes[row, 0].axis('off')

            axes[row, 1].imshow(image_np)
            axes[row, 1].imshow(pred_np, alpha=0.6, cmap='Reds')
            axes[row, 1].set_title(f"Overlay: {class_name}")
            axes[row, 1].axis('off')

            if has_gt and class_name in ground_truth:
                gt_np = ground_truth[class_name].detach().cpu().numpy()

                # Green = ground truth only (missed), red = prediction only
                # (false positive), yellow = agreement.
                comparison = np.zeros((*gt_np.shape, 3))
                comparison[gt_np > 0.5] = [0, 1, 0]
                comparison[pred_np > 0.5] = [1, 0, 0]
                comparison[(gt_np > 0.5) & (pred_np > 0.5)] = [1, 1, 0]

                axes[row, 2].imshow(image_np)
                axes[row, 2].imshow(comparison, alpha=0.6)
                axes[row, 2].set_title(f"Comparison: {class_name}")
                axes[row, 2].axis('off')

        fig.suptitle(title)
        plt.tight_layout()
        return fig
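
    # Usage sketch (hypothetical tensors; the shapes are assumptions, any
    # CHW image with matching HW masks works):
    #   viz = SegmentationVisualizer()
    #   fig = viz.visualize_segmentation(
    #       image=torch.rand(3, 512, 512),
    #       predictions={'building': (torch.rand(512, 512) > 0.5).float()},
    #   )
    #   fig.savefig('segmentation.png')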

    def create_combined_mask(self, masks: Dict[str, torch.Tensor]) -> np.ndarray:
        """Create a combined RGB mask visualization for multiple classes."""
        if not masks:
            # Keep the return shape consistent with the non-empty case.
            return np.zeros((512, 512, 3))

        first_mask = next(iter(masks.values()))
        combined = np.zeros((*first_mask.shape, 3))

        for class_name, mask in masks.items():
            mask_np = mask.detach().cpu().numpy()
            color = self.class_colors.get(class_name, [1, 1, 1])

            # Accumulate each class mask into the RGB channels.
            for c in range(3):
                combined[:, :, c] += mask_np * color[c]

        # Overlapping classes can push values above 1, so clip.
        combined = np.clip(combined, 0, 1)
        return combined
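
    # The combined mask is an (H, W, 3) float array in [0, 1]; overlapping
    # classes blend additively before clipping, so e.g. 'building' (red)
    # over 'road' (green) renders yellow.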

    def visualize_attention_maps(
        self,
        image: torch.Tensor,
        attention_maps: torch.Tensor,
        class_names: List[str],
        title: str = "Attention Maps"
    ) -> plt.Figure:
        """Visualize attention maps for different classes."""
        num_classes = len(class_names)
        fig, axes = plt.subplots(2, num_classes, figsize=(num_classes * 4, 8))
        # Keep a 2D axes grid even when there is a single class.
        axes = np.asarray(axes).reshape(2, -1)

        image_np = image.detach().permute(1, 2, 0).cpu().numpy()
        if image_np.min() < 0 or image_np.max() > 1:
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())

        for i in range(num_classes):
            axes[0, i].imshow(image_np)
            axes[0, i].set_title(f"Original - {class_names[i]}")
            axes[0, i].axis('off')

        attention_np = attention_maps.detach().cpu().numpy()
        for i in range(min(num_classes, attention_np.shape[0])):
            # Upsample the attention map to the image resolution.
            attention_map = cv2.resize(
                attention_np[i].astype(np.float32),
                (image_np.shape[1], image_np.shape[0])
            )

            axes[1, i].imshow(attention_map, cmap='hot')
            axes[1, i].set_title(f"Attention - {class_names[i]}")
            axes[1, i].axis('off')

        fig.suptitle(title)
        plt.tight_layout()
        return fig
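
    # Usage sketch (hypothetical shapes): attention_maps is expected as a
    # (num_classes, h, w) tensor; each map is resized to the image size.
    #   fig = viz.visualize_attention_maps(
    #       image=torch.rand(3, 512, 512),
    #       attention_maps=torch.rand(2, 32, 32),
    #       class_names=['building', 'road'],
    #   )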

    def visualize_prompt_points(
        self,
        image: torch.Tensor,
        prompts: List[Dict],
        title: str = "Prompt Points"
    ) -> plt.Figure:
        """Visualize prompt points and boxes on the image."""
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))

        image_np = image.detach().permute(1, 2, 0).cpu().numpy()
        if image_np.min() < 0 or image_np.max() > 1:
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())

        ax.imshow(image_np)

        # One distinct color per prompt.
        colors = plt.cm.Set3(np.linspace(0, 1, max(len(prompts), 1)))

        for i, prompt in enumerate(prompts):
            color = colors[i]

            if prompt['type'] == 'point':
                x, y = prompt['data']
                ax.scatter(x, y, c=[color], s=100, marker='o',
                           label=f"{prompt['class']} (point)")

            elif prompt['type'] == 'box':
                x1, y1, x2, y2 = prompt['data']
                rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                         linewidth=2, edgecolor=color,
                                         facecolor='none',
                                         label=f"{prompt['class']} (box)")
                ax.add_patch(rect)

        ax.set_title(title)
        if prompts:
            ax.legend()
        ax.axis('off')

        return fig
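
    # Each prompt is a dict with 'type' ('point' or 'box'), 'data'
    # ((x, y) for points, (x1, y1, x2, y2) for boxes), and 'class'.
    # Usage sketch with hypothetical coordinates:
    #   fig = viz.visualize_prompt_points(
    #       image=torch.rand(3, 512, 512),
    #       prompts=[{'type': 'point', 'data': (100, 150), 'class': 'building'},
    #                {'type': 'box', 'data': (50, 60, 200, 220), 'class': 'road'}],
    #   )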


class ExperimentVisualizer:
    """Visualization tools for experiment results and comparisons."""

    def __init__(self):
        self.segmentation_visualizer = SegmentationVisualizer()

    def plot_metrics_comparison(
        self,
        results: Dict[str, List[float]],
        metric_name: str = "IoU",
        title: str = "Metrics Comparison"
    ) -> plt.Figure:
        """Plot comparison of metrics across different methods/strategies."""
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))

        methods = list(results.keys())
        values = [np.mean(results[method]) for method in methods]
        errors = [np.std(results[method]) for method in methods]

        bars = ax.bar(methods, values, yerr=errors, capsize=5, alpha=0.7)

        # Annotate each bar with its mean value.
        for bar, value in zip(bars, values):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')

        ax.set_title(title)
        ax.set_ylabel(metric_name)
        ax.set_xlabel("Methods")
        ax.grid(True, alpha=0.3)

        plt.xticks(rotation=45)
        plt.tight_layout()

        return fig
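
    # Usage sketch: results maps a method name to its per-episode scores
    # (hypothetical values):
    #   fig = exp_viz.plot_metrics_comparison(
    #       {'zero-shot': [0.41, 0.44], 'few-shot': [0.58, 0.61]},
    #       metric_name='IoU',
    #   )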

    def plot_learning_curves(
        self,
        episode_metrics: List[Dict[str, float]],
        metric_name: str = "iou"
    ) -> plt.Figure:
        """Plot learning curves over episodes."""
        fig, ax = plt.subplots(1, 1, figsize=(12, 6))

        episodes = range(1, len(episode_metrics) + 1)
        values = [ep.get(metric_name, 0) for ep in episode_metrics]

        ax.plot(episodes, values, 'b-', linewidth=2, label=f'{metric_name.upper()}')

        # Smooth the raw curve with a moving average once enough points exist.
        window_size = min(10, len(values) // 4)
        if window_size > 1:
            moving_avg = np.convolve(values, np.ones(window_size) / window_size, mode='valid')
            ax.plot(episodes[window_size - 1:], moving_avg, 'r--', linewidth=2,
                    label=f'Moving Average (window={window_size})')

        ax.set_title(f"Learning Curve - {metric_name.upper()}")
        ax.set_xlabel("Episode")
        ax.set_ylabel(metric_name.upper())
        ax.grid(True, alpha=0.3)
        ax.legend()

        plt.tight_layout()
        return fig
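
    # Usage sketch: episode_metrics is a list of per-episode metric dicts
    # (hypothetical values):
    #   fig = exp_viz.plot_learning_curves(
    #       [{'iou': 0.40, 'dice': 0.50}, {'iou': 0.45, 'dice': 0.55}],
    #       metric_name='iou',
    #   )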

    def plot_shot_analysis(
        self,
        shot_results: Dict[int, List[float]],
        metric_name: str = "iou"
    ) -> plt.Figure:
        """Plot performance analysis across different numbers of shots."""
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))

        shots = sorted(shot_results.keys())
        means = [np.mean(shot_results[shot]) for shot in shots]
        stds = [np.std(shot_results[shot]) for shot in shots]

        ax.errorbar(shots, means, yerr=stds, marker='o', linewidth=2,
                    capsize=5, capthick=2)

        ax.set_title(f"Performance vs Number of Shots - {metric_name.upper()}")
        ax.set_xlabel("Number of Shots")
        ax.set_ylabel(f"Mean {metric_name.upper()}")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        return fig
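
    # Usage sketch: shot_results maps shot counts to score lists
    # (hypothetical values):
    #   fig = exp_viz.plot_shot_analysis({1: [0.4, 0.5], 5: [0.6, 0.65]})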

    def plot_prompt_strategy_comparison(
        self,
        strategy_results: Dict[str, Dict[str, float]],
        metric_name: str = "mean_iou"
    ) -> plt.Figure:
        """Plot comparison of different prompt strategies."""
        fig, ax = plt.subplots(1, 1, figsize=(12, 6))

        strategies = list(strategy_results.keys())
        values = [strategy_results[s].get(metric_name, 0) for s in strategies]
        # A 'mean_<x>' metric pairs with a 'std_<x>' entry for error bars.
        errors = [strategy_results[s].get(f'std_{metric_name.split("_")[-1]}', 0)
                  for s in strategies]

        bars = ax.bar(strategies, values, yerr=errors, capsize=5, alpha=0.7)

        # Annotate each bar with its value.
        for bar, value in zip(bars, values):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')

        ax.set_title(f"Prompt Strategy Comparison - {metric_name}")
        ax.set_ylabel(metric_name.replace('_', ' ').title())
        ax.set_xlabel("Strategy")
        ax.grid(True, alpha=0.3)

        plt.xticks(rotation=45)
        plt.tight_layout()

        return fig
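
    # Usage sketch: strategy_results maps a strategy name to summary stats;
    # 'mean_iou' pairs with 'std_iou' for error bars (hypothetical values):
    #   fig = exp_viz.plot_prompt_strategy_comparison(
    #       {'points': {'mean_iou': 0.55, 'std_iou': 0.04},
    #        'boxes': {'mean_iou': 0.62, 'std_iou': 0.03}},
    #   )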

    def create_comprehensive_report(
        self,
        experiment_results: Dict,
        output_dir: str,
        experiment_name: str = "experiment"
    ):
        """Create a comprehensive visualization report."""
        os.makedirs(output_dir, exist_ok=True)

        # Learning curves for each available metric.
        if 'episode_metrics' in experiment_results:
            for metric in ['iou', 'dice', 'precision', 'recall']:
                fig = self.plot_learning_curves(
                    experiment_results['episode_metrics'],
                    metric
                )
                fig.savefig(os.path.join(output_dir, f'{experiment_name}_learning_curve_{metric}.png'))
                plt.close(fig)

        # Per-class learning curves.
        if 'class_metrics' in experiment_results:
            class_results = experiment_results['class_metrics']
            for class_name, metrics in class_results.items():
                if isinstance(metrics, list):
                    fig = self.plot_learning_curves(metrics, 'iou')
                    fig.savefig(os.path.join(output_dir, f'{experiment_name}_class_{class_name}.png'))
                    plt.close(fig)

        # Performance versus number of shots.
        if 'shot_analysis' in experiment_results:
            for metric in ['iou', 'dice']:
                fig = self.plot_shot_analysis(
                    experiment_results['shot_analysis'],
                    metric
                )
                fig.savefig(os.path.join(output_dir, f'{experiment_name}_shot_analysis_{metric}.png'))
                plt.close(fig)

        # Prompt strategy comparison.
        if 'strategy_comparison' in experiment_results:
            for metric in ['mean_iou', 'mean_dice']:
                fig = self.plot_prompt_strategy_comparison(
                    experiment_results['strategy_comparison'],
                    metric
                )
                fig.savefig(os.path.join(output_dir, f'{experiment_name}_strategy_comparison_{metric}.png'))
                plt.close(fig)

        print(f"Comprehensive report saved to {output_dir}")
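
    # Expected experiment_results layout (all keys optional):
    #   {
    #       'episode_metrics':     [{'iou': ..., 'dice': ..., ...}, ...],
    #       'class_metrics':       {class_name: [per-episode dicts]},
    #       'shot_analysis':       {num_shots: [scores]},
    #       'strategy_comparison': {strategy: {'mean_iou': ..., ...}},
    #   }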


class AttentionVisualizer:
    """Specialized visualizer for attention mechanisms."""

    def __init__(self):
        self.segmentation_visualizer = SegmentationVisualizer()

    def visualize_cross_attention(
        self,
        image: torch.Tensor,
        text_tokens: List[str],
        attention_weights: torch.Tensor,
        title: str = "Cross-Attention Visualization"
    ) -> plt.Figure:
        """Visualize cross-attention between image patches and text tokens."""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        image_np = image.detach().permute(1, 2, 0).cpu().numpy()
        if image_np.min() < 0 or image_np.max() > 1:
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())

        axes[0, 0].imshow(image_np)
        axes[0, 0].set_title("Original Image")
        axes[0, 0].axis('off')

        axes[0, 1].text(0.1, 0.5, '\n'.join(text_tokens), fontsize=12,
                        verticalalignment='center')
        axes[0, 1].set_title("Text Tokens")
        axes[0, 1].axis('off')

        # Raw (image patches x text tokens) attention matrix.
        attention_np = attention_weights.detach().cpu().numpy()
        sns.heatmap(attention_np, ax=axes[1, 0], cmap='viridis')
        axes[1, 0].set_title("Attention Heatmap")
        axes[1, 0].set_xlabel("Text Tokens")
        axes[1, 0].set_ylabel("Image Patches")

        # Average over tokens, fold the patch axis back into a square grid
        # (assumes the patch count is a perfect square), then upsample.
        attention_map = np.mean(attention_np, axis=1)
        grid = int(np.sqrt(len(attention_map)))
        attention_map = attention_map.reshape(grid, grid)
        attention_map = cv2.resize(attention_map.astype(np.float32),
                                   (image_np.shape[1], image_np.shape[0]))

        axes[1, 1].imshow(image_np)
        axes[1, 1].imshow(attention_map, alpha=0.6, cmap='hot')
        axes[1, 1].set_title("Attention Overlay")
        axes[1, 1].axis('off')

        fig.suptitle(title)
        plt.tight_layout()
        return fig
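

if __name__ == "__main__":
    # Minimal smoke test on synthetic data -- a sketch for local sanity
    # checks only; tensor shapes (3x256x256 image, 256x256 masks) are
    # assumptions, not requirements of any particular dataset.
    torch.manual_seed(0)
    image = torch.rand(3, 256, 256)
    predictions = {
        'building': (torch.rand(256, 256) > 0.7).float(),
        'road': (torch.rand(256, 256) > 0.8).float(),
    }
    ground_truth = {
        'building': (torch.rand(256, 256) > 0.7).float(),
        'road': (torch.rand(256, 256) > 0.8).float(),
    }

    seg_viz = SegmentationVisualizer()
    fig = seg_viz.visualize_segmentation(image, predictions, ground_truth)
    plt.close(fig)

    exp_viz = ExperimentVisualizer()
    fig = exp_viz.plot_shot_analysis({1: [0.4, 0.5], 5: [0.6, 0.65], 10: [0.7, 0.72]})
    plt.close(fig)

    print("Smoke-test figures created successfully.")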