Spaces:
Runtime error
Runtime error
# ************************************************************************* | |
# Copyright (2023) Bytedance Inc. | |
# | |
# Copyright (2023) DragDiffusion Authors | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ************************************************************************* | |
# Evaluate the mean distance between the desired target points and the final positions of the handle points.
import argparse | |
import os | |
import pickle | |
import numpy as np | |
import PIL | |
from PIL import Image | |
from torchvision.transforms import PILToTensor | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dift_sd import SDFeaturizer | |
from pytorch_lightning import seed_everything | |
# DragBench categories; each one is a sub-directory of both the dataset root
# and every evaluated result root.
ALL_CATEGORIES = (
    'art_work',
    'land_scape',
    'building_city_view',
    'building_countryside_view',
    'animals',
    'human_head',
    'human_upper_body',
    'human_full_body',
    'interior_design',
    'other_objects',
)

# Root directory holding the original benchmark images and their meta data.
ORIGINAL_IMG_ROOT = 'drag_bench_data/'

# Feature-extraction settings shared by the source and the dragged image.
DIFT_TIMESTEP = 261
DIFT_UP_FT_INDEX = 1
DIFT_ENSEMBLE_SIZE = 8


def split_handle_target_points(points):
    """Split an interleaved point list into handle and target points.

    Args:
        points: Sequence alternating handle, target, handle, target, ...
            where every point is given in (x, y) order.

    Returns:
        A ``(handle_points, target_points)`` pair of equally long lists of
        2-element tensors in (row, col) order, i.e. (y, x).

    Raises:
        ValueError: If ``points`` has an odd length, so that the last handle
            point has no matching target point.
    """
    points = list(points)
    if len(points) % 2 != 0:
        raise ValueError(
            'expected an even number of points (handle/target pairs), got %d'
            % len(points))
    # Swap (x, y) -> (row, col) while separating even (handle) and odd
    # (target) entries.
    handle_points = [torch.tensor([p[1], p[0]]) for p in points[0::2]]
    target_points = [torch.tensor([p[1], p[0]]) for p in points[1::2]]
    return handle_points, target_points


def load_image_pair(source_image_path, dragged_image_path):
    """Load the source and dragged images as tensors scaled to [-1, 1].

    The dragged image is resized (bilinear) to the size of the source image.
    Both images are converted to RGB so that files stored with an alpha
    channel or as grayscale still yield 3-channel tensors.

    Returns:
        A ``(source_image_tensor, dragged_image_tensor)`` pair.
    """
    # Context managers make sure the underlying files are closed; convert()
    # and resize() return fully loaded copies that outlive the file handles.
    with Image.open(source_image_path) as source_file, \
            Image.open(dragged_image_path) as dragged_file:
        source_image_PIL = source_file.convert('RGB')
        dragged_image_PIL = dragged_file.convert('RGB').resize(
            source_image_PIL.size, PIL.Image.BILINEAR)

    to_tensor = PILToTensor()
    source_image_tensor = (to_tensor(source_image_PIL) / 255.0 - 0.5) * 2
    dragged_image_tensor = (to_tensor(dragged_image_PIL) / 255.0 - 0.5) * 2
    return source_image_tensor, dragged_image_tensor


def extract_features(dift, image_tensor, prompt, size):
    """Run the featurizer on one image and upsample the result to ``size``.

    Args:
        dift: Featurizer object exposing ``forward(image, prompt=..., t=...,
            up_ft_index=..., ensemble_size=...)``.
        image_tensor: Image tensor passed to the featurizer.
        prompt: Text prompt passed to the featurizer.
        size: ``(H, W)`` the feature map is bilinearly interpolated to.
    """
    features = dift.forward(image_tensor,
                            prompt=prompt,
                            t=DIFT_TIMESTEP,
                            up_ft_index=DIFT_UP_FT_INDEX,
                            ensemble_size=DIFT_ENSEMBLE_SIZE)
    return F.interpolate(features, size, mode='bilinear')


def point_match_distance(ft_source, ft_dragged, handle_point, target_point):
    """Return the distance between a target point and the matched handle.

    The feature vector at ``handle_point`` in ``ft_source`` is compared with
    every location of ``ft_dragged`` by cosine similarity; the location with
    the highest similarity is taken as the final position of the handle.

    Args:
        ft_source: Feature map of the source image, indexed as
            ``[0, channel, row, col]``.
        ft_dragged: Feature map of the dragged image; must broadcast against
            a ``(1, C, 1, 1)`` vector (assumed ``(1, C, H, W)``).
        handle_point: 2-element integer tensor, (row, col) of the handle.
        target_point: 2-element tensor, (row, col) of the desired target.

    Returns:
        Euclidean distance (float) between ``target_point`` and the matched
        (row, col) location.
    """
    num_channel = ft_source.size(1)
    src_vec = ft_source[0, :, handle_point[0], handle_point[1]].view(
        1, num_channel, 1, 1)
    cos_map = F.cosine_similarity(src_vec, ft_dragged, dim=1).cpu().numpy()[0]  # H, W
    # the matched row, col
    max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape)
    matched = torch.tensor([int(v) for v in max_rc])
    return (target_point - matched).float().norm().item()


def evaluate_root(dift, target_root,
                  original_img_root=ORIGINAL_IMG_ROOT,
                  categories=ALL_CATEGORIES):
    """Compute the mean handle-to-target distance for one result root.

    Args:
        dift: Featurizer used for semantic correspondence.
        target_root: Directory with ``<category>/<sample>/dragged_image.png``.
        original_img_root: Directory with the benchmark samples, each holding
            ``original_image.png`` and ``meta_data.pkl``.
        categories: Category sub-directories to evaluate.

    Returns:
        Mean distance over all handle/target pairs, or ``nan`` when no pair
        was evaluated.
    """
    # fixing the seed for semantic correspondence
    seed_everything(42)
    all_dist = []
    for cat in categories:
        cat_dir = os.path.join(original_img_root, cat)
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # seeded run visit the samples in the same order on every machine.
        for file_name in sorted(os.listdir(cat_dir)):
            sample_dir = os.path.join(cat_dir, file_name)
            # Skip stray files such as '.DS_Store'; samples are directories.
            if not os.path.isdir(sample_dir):
                continue

            # NOTE: pickle can execute arbitrary code while loading; only
            # evaluate benchmark data obtained from a trusted source.
            with open(os.path.join(sample_dir, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            prompt = meta_data['prompt']
            # Stored points are (x, y); the helper returns (row, col) and
            # fails early on an unpaired handle point.
            handle_points, target_points = split_handle_target_points(
                meta_data['points'])

            source_image_tensor, dragged_image_tensor = load_image_pair(
                os.path.join(sample_dir, 'original_image.png'),
                os.path.join(target_root, cat, file_name, 'dragged_image.png'))

            _, H, W = source_image_tensor.shape
            ft_source = extract_features(dift, source_image_tensor, prompt, (H, W))
            ft_dragged = extract_features(dift, dragged_image_tensor, prompt, (H, W))

            for hp, tp in zip(handle_points, target_points):
                all_dist.append(
                    point_match_distance(ft_source, ft_dragged, hp, tp))

    if not all_dist:
        return float('nan')
    return torch.tensor(all_dist).mean().item()


def main():
    """Parse the command line and print the mean distance per result root."""
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--eval_root',
                        action='append',
                        help='root of dragging results for evaluation',
                        required=True)
    args = parser.parse_args()

    # using SD-2.1
    dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')

    for target_root in args.eval_root:
        mean_dist = evaluate_root(dift, target_root)
        print(target_root + ' mean distance: ', mean_dist)


if __name__ == '__main__':
    main()