import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

class RandomAffineAndRetMat(torch.nn.Module):
    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.interpolation = interpolation
        self.fill = fill
        self.center = center
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.
        Returns:
            (PIL Image or Tensor, Tensor): Affine-transformed image and the 3x3 affine matrix used.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        img_size = transforms.functional.get_image_size(img)
        ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
        affine_matrix = self.get_affine_matrix_from_params(ret)
        return transformed_image, affine_matrix
    def get_affine_matrix_from_params(self, params):
        degrees, translate, scale, shear = params
        degrees = torch.tensor(degrees)
        shear = torch.tensor(shear)

        # Convert the sampled parameters into individual transformation matrices
        rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
                                        [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
                                        [0, 0, 1]])
        translation_matrix = torch.tensor([[1, 0, translate[0]],
                                           [0, 1, translate[1]],
                                           [0, 0, 1]]).to(torch.float32)
        scaling_matrix = torch.tensor([[scale, 0, 0],
                                       [0, scale, 0],
                                       [0, 0, 1]])
        shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
                                        [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
                                        [0, 0, 1]])

        # Compose the matrices into a single affine transform
        affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
        return affine_matrix
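
# Illustrative usage sketch (not part of the original pipeline): apply the
# transform to a dummy CHW tensor. The image size and parameter ranges below
# are assumptions chosen only for the demo.
def _demo_random_affine():
    transform = RandomAffineAndRetMat(
        degrees=(-30, 30), translate=(0.1, 0.1), scale=(0.8, 1.2), shear=(-10, 10)
    )
    img = torch.rand(3, 128, 128)  # dummy image tensor
    transformed, affine_matrix = transform(img)
    print(transformed.shape, affine_matrix.shape)  # [3, 128, 128] and [3, 3]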
class GetTransformedCoords(nn.Module):
    def __init__(self, affine_matrix, center):
        super().__init__()
        self.affine_matrix = affine_matrix  # expected shape: [N, 3, 3], one matrix per coordinate row
        self.center = center

    def forward(self, _coords):
        # _coords: pixel coordinates, e.g. tensor([[43, 26], [44, 27], [45, 28]])
        center_x, center_y = self.center
        # Shift the coordinates so the transform center becomes the origin
        coords = _coords.clone()
        coords[:, 0] -= center_x
        coords[:, 1] -= center_y

        # Apply the affine transform to every coordinate in the batch
        homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
        transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)

        # Clamp to the image bounds (currently disabled):
        # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
        # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
        transformed_x = transformed_coordinates[:, 0]
        transformed_y = transformed_coordinates[:, 1]

        # Shift back to the original coordinate frame
        transformed_x += center_x
        transformed_y += center_y
        return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
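
# Illustrative sketch: map pixel coordinates through a batch of affine
# matrices. The [N, 3, 3] batched identity matrix and the (64, 64) center are
# demo assumptions; torch.bmm requires one matrix per coordinate row.
def _demo_transformed_coords():
    coords = torch.tensor([[43, 26], [44, 27], [45, 28]], dtype=torch.float32)
    affine = torch.eye(3).unsqueeze(0).repeat(coords.size(0), 1, 1)  # identity per point
    new_coords = GetTransformedCoords(affine, center=(64, 64))(coords)
    print(new_coords)  # identity transform: the same coordinates come back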
# Version of pairwise_distance that skips the square root
def pairwise_distance_squared(a, b):
    return torch.sum((a - b) ** 2, dim=-1)

def cosine_similarity(a, b):
    # Dot product of vectors a and b
    dot_product = torch.matmul(a, b)
    # Norms (magnitudes) of a and b
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms
    return dot_product / (norm_a * norm_b)

def batch_cosine_similarity(a, b):
    # Per-position dot products over the feature dimension, batched
    dot_product = torch.einsum('bnd,bnd->bn', a, b)
    # Norms (magnitudes) of a and b
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms
    return dot_product / (norm_a * norm_b)
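
# Quick sanity check (illustrative): batch_cosine_similarity should agree with
# F.cosine_similarity along the last dimension, up to the eps clamping that the
# built-in applies near zero-norm vectors.
def _demo_cosine_helpers():
    a = torch.randn(4, 10, 32)
    b = torch.randn(4, 10, 32)
    ours = batch_cosine_similarity(a, b)
    ref = F.cosine_similarity(a, b, dim=-1)
    print(torch.allclose(ours, ref, atol=1e-6))  # True for well-conditioned inputs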
class TripletLossBatch(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = F.pairwise_distance(anchor, positive, p=2)
        distance_negative = F.pairwise_distance(anchor, negative, p=2)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()

class TripletLossCosineSimilarity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = 1 - batch_cosine_similarity(anchor, positive)
        distance_negative = 1 - batch_cosine_similarity(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()
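
# Illustrative sketch: both triplet losses on random embeddings. The shapes and
# margins are demo assumptions; the cosine version expects [batch, n, dim]
# because it relies on batch_cosine_similarity's 'bnd,bnd->bn' einsum.
def _demo_triplet_losses():
    anchor, positive, negative = (torch.randn(8, 16, 32) for _ in range(3))
    print(TripletLossBatch()(anchor, positive, negative, margin=1.0))
    print(TripletLossCosineSimilarity()(anchor, positive, negative, margin=0.5))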
def imsave(img):
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5  # undo the [-1, 1] normalization
    npimg = img.detach().cpu().numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # plt.show()
    # Save the grid as an image; clip to [0, 1] so the uint8 cast cannot wrap around
    npimg = np.transpose(npimg, (1, 2, 0))
    npimg = np.clip(npimg, 0.0, 1.0) * 255
    npimg = npimg.astype(np.uint8)
    Image.fromarray(npimg).save('sample.png')

def norm_img(img):
    return (img - img.min()) / (img.max() - img.min())

def norm_img2(img):
    return (img - img.min()) / (img.max() - img.min()) * 255
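
# Illustrative sketch: save a grid of fake images with imsave. Assumes inputs
# roughly in [-1, 1], matching the un-normalization that imsave applies.
def _demo_imsave():
    batch = torch.rand(8, 3, 64, 64) * 2 - 1  # fake normalized batch
    imsave(batch)  # writes 'sample.png' in the working directory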
class DistanceMapLogger:
    def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)
        if x_coords is None:
            x_coords = [69] * batch_size
        if y_coords is None:
            y_coords = [42] * batch_size

        # Extract a 3-channel map with PCA for visualization
        pca = PCA(n_components=3)
        pca_result = pca.fit_transform(feature_map.permute(0, 2, 3, 1).reshape(-1, feature_dim).detach().cpu().numpy())  # run PCA over all spatial positions
        reshaped_pca_result = pca_result.reshape(batch_size, image_size, image_size, 3)  # back to (batch, H, W, 3)

        sample_num = 0
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (y, x) for each sample
        vector = vectors[sample_num]

        # Compute the distance to the reference vector over every feature map in the batch:
        # permute feature_map and flatten the spatial dimensions (reshape, not view, since permute breaks contiguity)
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).reshape(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
        # Map distances to [0, 1] with a sech^2 bump so small distances light up
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # norm_batch_distance_map[:,0,0] = 0.001

        # Visualization and saving (the 5x4 grid assumes a batch size of at least 5)
        fig, axes = plt.subplots(5, 4, figsize=(20, 25))
        for ax in axes.flatten():
            ax.axis('off')
        # Remove spacing between subplots and the outer margins
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

        # Visualize the distance maps
        for i in range(5):
            axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
            if i == sample_num:
                axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
            distance_map = torch.cat(((norm_batch_distance_map[i] / norm_batch_distance_map[i].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
            alpha = 0.9  # transparency factor for the heatmap overlay
            blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
            axes[i, 1].imshow(norm_img(blended_tensor.permute(1, 2, 0).detach().cpu()))
            axes[i, 2].imshow(norm_img(img[i].permute(1, 2, 0).detach().cpu()))
            axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
        plt.savefig(save_path)
        plt.close(fig)
    def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)
        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (y, x) for each sample
        vector = vectors[source_num]

        # Compute the distance to the source vector over every feature map in the batch:
        # permute feature_map and flatten the spatial dimensions (reshape, not view, since permute breaks contiguity)
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).reshape(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # norm_batch_distance_map[:,0,0] = 0.001

        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]
        alpha = 0.9
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        return source_map, target_map, blended_source, blended_target
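
# Illustrative usage sketch for DistanceMapLogger. The batch size of 5 (the
# __call__ grid plots exactly 5 rows), the feature-map shape, the coordinates,
# and the output filename are all demo assumptions.
def _demo_distance_map_logger():
    imgs = torch.rand(5, 3, 96, 96)
    feats = torch.randn(5, 32, 96, 96)  # [batch, feature_dim, H, W]
    logger = DistanceMapLogger()
    logger(imgs, feats, save_path='distance_maps.png', x_coords=[40] * 5, y_coords=[40] * 5)
    src_map, tgt_map, blend_src, blend_tgt = logger.get_heatmaps(imgs, feats, x_coords=40, y_coords=40)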