Spaces:
Configuration error
Configuration error
from typing import Optional, Union, Tuple, List, Callable, Dict | |
from tqdm.notebook import tqdm | |
import torch | |
import math | |
from typing import List, Tuple, Union | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import os | |
import re | |
import torch | |
from IPython.display import display | |
from sklearn.cluster import KMeans | |
import matplotlib.pyplot as plt | |
from .ptp_utils import * | |
import torchvision.transforms as transforms | |
from sklearn.decomposition import PCA | |
import pickle as pkl | |
import torch.nn.functional as F | |
import argparse | |
from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score, v_measure_score | |
transform_train = transforms.Compose([ | |
transforms.ToPILImage(), | |
]) | |
pca = PCA(n_components=3) | |
def save_mask(mask, output_name): | |
mask_image = transform_train(mask.float()) | |
mask_image.save(output_name) | |
def show_image(image): | |
image = 255 * image / image.max() | |
image = image.unsqueeze(-1).expand(*image.shape, 3) | |
image = image.numpy().astype(np.uint8) | |
image = np.array(Image.fromarray(image).resize((256, 256))) | |
return image | |
def cluster2noun_mod(clusters, background_segment_threshold, num_segments, nouns, cross_attention): | |
REPEAT=clusters.shape[0]/cross_attention.shape[0] | |
result = {} | |
result_mask={} | |
nouns_indices = [index for (index, word) in nouns] | |
nouns_maps = cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]] | |
nouns_maps = cross_attention.unsqueeze(-1).cpu().numpy() | |
normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) | |
for i in range(nouns_maps.shape[-1]): | |
curr_noun_map = nouns_maps[:, :, i].repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) | |
normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max() | |
for c in range(num_segments): | |
cluster_mask = np.zeros_like(clusters) | |
cluster_mask[clusters == c] = 1 | |
score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))] | |
scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps] | |
result[c] = nouns[np.argmax(np.array(scores))] if max(scores) > background_segment_threshold else "BG" | |
result_mask[c]=cluster_mask | |
return result, result_mask | |
def cluster2noun_(clusters, background_segment_threshold, num_segments, nouns, cross_attention, attention_threshold=0.2): | |
REPEAT = clusters.shape[0] // cross_attention.shape[0] | |
result = {} | |
result_mask = {} | |
print('cross_attention',cross_attention.shape) | |
# 提取名词索引和对应的注意力图 | |
nouns_indices = [index for (index, word) in nouns] | |
nouns_maps = cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]] | |
print('nouns_maps', nouns_maps.shape) | |
normalized_nouns_maps = nouns_maps | |
#normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) | |
# 标准化注意力图并应用阈值 | |
# for i in range(nouns_maps.shape[-1]): | |
# curr_noun_map = nouns_maps[:, :, i].repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) | |
# normalized_map = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max() | |
# # 应用阈值,将低于阈值的部分设为 0 | |
# #normalized_map[normalized_map < attention_threshold] = 0 | |
# normalized_nouns_maps[:, :, i] = normalized_map | |
print('normalized_nouns_maps', normalized_nouns_maps.shape) | |
#show_normalized_nouns_maps(normalized_nouns_maps, nouns, logdir) | |
# 用于记录已经分配的单词 | |
assigned_nouns = set() | |
for c in range(num_segments): | |
cluster_mask = np.zeros_like(clusters) | |
cluster_mask[clusters == c] = 1 | |
score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))] | |
scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps] | |
# 找出最高分的名词,并确保未被分配过 | |
sorted_scores_indices = np.argsort(scores)[::-1] | |
assigned_word = None | |
for idx in sorted_scores_indices: | |
if scores[idx] > background_segment_threshold and nouns[idx] not in assigned_nouns: | |
assigned_word = nouns[idx] | |
assigned_nouns.add(nouns[idx]) # 记录这个单词已分配 | |
break | |
# 如果没有找到合适的名词,强制分配最高分的未分配名词 | |
if assigned_word is None and len(sorted_scores_indices) > 0: | |
for idx in sorted_scores_indices: | |
if nouns[idx] not in assigned_nouns: | |
assigned_word = nouns[idx] | |
assigned_nouns.add(nouns[idx]) # 记录这个单词已分配 | |
break | |
if assigned_word: | |
result[c] = assigned_word | |
result_mask[c] = cluster_mask | |
return result, result_mask | |
def aggregate_attention( attention_maps, | |
res: int, from_where: List[str], | |
is_cross: bool, select: int, prompts,): | |
out = [] | |
num_pixels = res ** 2 | |
for location in from_where: | |
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: | |
if item.shape[1] == num_pixels: | |
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] | |
out.append(cross_maps) | |
out = torch.cat(out, dim=0) | |
out = out.sum(0) / out.shape[0] | |
return out.cpu(), attention_maps | |
def cluster(self_attention, num_segments,): | |
np.random.seed(1) | |
frames = self_attention.shape[0] | |
video_clusters = [] | |
for i in range(frames): | |
per_frame_attention = self_attention[i] | |
print('per_frame_attention',per_frame_attention.shape) | |
resolution, feat_dim = per_frame_attention.shape[0], per_frame_attention.shape[-1] | |
attn = per_frame_attention.cpu().numpy().reshape(resolution ** 2, feat_dim) | |
kmeans = KMeans(n_clusters=num_segments, n_init=10).fit(attn) | |
clusters = kmeans.labels_ | |
clusters = clusters.reshape(resolution, resolution) | |
video_clusters.append(clusters) | |
return video_clusters | |
def run_clusters(avg_dict, resolution, dict_key, save_path, special_name, num_segments): | |
video_clusters = cluster(avg_dict[dict_key][resolution], num_segments,) | |
npy_name=f'cluster_{dict_key}_{resolution}_{special_name}.npy' | |
np.save(os.path.join(save_path, npy_name), video_clusters) | |
for i in range(len(video_clusters)): | |
clusters = video_clusters[i] | |
output_name=f'cluster_{dict_key}_{resolution}_{i}.png' | |
plt.imshow(clusters) | |
plt.axis('off') | |
plt.savefig(os.path.join(save_path, output_name), bbox_inches='tight', pad_inches=0) | |
def read_pkl(path,): | |
with open(path,'rb') as f: | |
dict_ = pkl.load(f) | |
return dict_ | |
def draw_pca(avg_dict, resolution, dict_key, save_path, special_name): | |
RESOLUTION=resolution | |
if avg_dict[dict_key][RESOLUTION].__len__() == 0: | |
return | |
before_pca = avg_dict[dict_key][RESOLUTION] | |
frames = before_pca.shape[0] | |
for i in range(frames): | |
frame = before_pca[i] | |
print('frame',frame.dtype) | |
if isinstance(frame, torch.Tensor): | |
frame = frame.reshape(RESOLUTION * RESOLUTION, -1).cpu().numpy() | |
else: | |
frame = frame.reshape(RESOLUTION * RESOLUTION, -1) | |
pca.fit(frame) | |
after_pca = pca.transform(frame) | |
after_pca = after_pca.reshape(RESOLUTION,RESOLUTION,-1) | |
pca_img_min = after_pca.min(axis=(0, 1)) | |
pca_img_max = after_pca.max(axis=(0, 1)) | |
pca_img = (after_pca - pca_img_min) / (pca_img_max - pca_img_min) | |
output_name=f'pca_{dict_key}_{resolution}_{i}.png' | |
pca_img = Image.fromarray((pca_img * 255).astype(np.uint8)) | |
pca_img=pca_img.resize((512,512)) | |
pca_img.save(os.path.join(save_path, output_name)) | |
def image_normalize(numpy_array, save_path,output_name): | |
numpy_array=numpy_array.cpu().numpy() | |
img_min = numpy_array.min() | |
img_max = numpy_array.max() | |
normalize_array = (numpy_array - img_min) / (img_max - img_min) | |
plt.imshow(normalize_array) | |
plt.axis('off') | |
plt.savefig(os.path.join(save_path, output_name), bbox_inches='tight', pad_inches=0) | |
def cross_cosine_with_name(resolution, inv_avg_dict, denoise_avg_dict, indice, save_path, save_crossattn=False, noun_name = ''): | |
inv_cross_attn = inv_avg_dict['attn'][resolution][:,:,indice] | |
denoise_cross_attn = denoise_avg_dict['attn'][resolution][:,:,indice] | |
if save_crossattn: | |
image_normalize(inv_cross_attn, save_path, f'crossattn_{resolution}_inv_{noun_name}.png') | |
image_normalize(denoise_cross_attn, save_path, f'crossattn_{resolution}_denoise_{noun_name}.png') | |
return F.cosine_similarity(inv_cross_attn.reshape(1,-1), denoise_cross_attn.reshape(1,-1)) | |
def cross_cosine(resolution, inv_avg_dict, denoise_avg_dict, indice, save_path, save_crossattn=False,): | |
inv_cross_attn = inv_avg_dict['attn'][resolution][:,:,indice] | |
denoise_cross_attn = denoise_avg_dict['attn'][resolution][:,:,indice] | |
if save_crossattn: | |
image_normalize(inv_cross_attn, save_path, f'crossattn_{resolution}_inv.png') | |
image_normalize(denoise_cross_attn, save_path, f'crossattn_{resolution}_denoise.png') | |
return F.cosine_similarity(inv_cross_attn.reshape(1,-1), denoise_cross_attn.reshape(1,-1)) | |
def save_crossattn(input_path, caption, inv_cross_avg_dict, denoise_cross_avg_dict, results_folder, RES=16): | |
org_image = Image.open(input_path).convert("RGB") | |
prompts=["<|startoftext|>",] + caption.split(' ') + ["<|endoftext|>",] | |
inv_crossattn = inv_cross_avg_dict['attn'][RES] | |
denoise_crossattn = denoise_cross_avg_dict['attn'][RES] | |
attn_img1, mask_img1, _ = show_cross_attention_plus_orig_img(prompts, inv_crossattn, orig_image=org_image) | |
attn_img2, mask_img2, _ = show_cross_attention_plus_orig_img(prompts, denoise_crossattn, orig_image=org_image) | |
attn_img1.save(os.path.join(results_folder,'crossattn_inv.png')) | |
attn_img2.save(os.path.join(results_folder,'crossattn_denoise.png')) | |
mask_img1.save(os.path.join(results_folder,'crossattn_inv_mask.png')) | |
mask_img2.save(os.path.join(results_folder,'crossattn_denoise_mask.png')) |