# Source: DragDiffusion / drag_bench_evaluation / run_eval_point_matching.py
# Uploaded by GwanHyeong via huggingface_hub (commit 8c8af64, verified; 5.03 kB).
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
# Run evaluation of the mean distance between the desired target points and the final positions of the handle points.
import argparse
import os
import pickle
import numpy as np
import PIL
from PIL import Image
from torchvision.transforms import PILToTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from dift_sd import SDFeaturizer
from pytorch_lightning import seed_everything
# DragBench categories; each is a sub-directory of the original-image root.
_ALL_CATEGORY = [
    'art_work',
    'land_scape',
    'building_city_view',
    'building_countryside_view',
    'animals',
    'human_head',
    'human_upper_body',
    'human_full_body',
    'interior_design',
    'other_objects',
]


def _parse_args():
    """Parse the command-line arguments of the point-matching evaluation."""
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--eval_root',
                        action='append',
                        help='root of dragging results for evaluation',
                        required=True)
    # Previously hard-coded; the default preserves the original behaviour.
    parser.add_argument('--original_img_root',
                        default='drag_bench_data/',
                        help='root of the original DragBench data')
    return parser.parse_args()


def _load_image_tensor(path, size=None):
    """Load an image as a (3, H, W) float tensor scaled to [-1, 1].

    Args:
        path: image file path.
        size: optional PIL ``(width, height)``; if given, the image is
            bilinearly resized to it before conversion.
    """
    with Image.open(path) as raw:
        # Force 3 channels: RGBA / palette / grayscale PNGs would otherwise
        # yield a tensor with the wrong channel count for the featurizer.
        img = raw.convert('RGB')
    if size is not None:
        img = img.resize(size, PIL.Image.BILINEAR)
    return (PILToTensor()(img) / 255.0 - 0.5) * 2


def _split_points(points):
    """Split alternating handle/target points into (row, col) tensors.

    ``points`` alternates handle, target, handle, target, ... and each point
    is stored in (x, y) order; the returned tensors are (row, col) = (y, x).

    Returns:
        (handle_points, target_points): two equally long lists of tensors.

    Raises:
        ValueError: if ``points`` has an odd length (an unpaired handle).
    """
    if len(points) % 2 != 0:
        raise ValueError(
            f'expected an even number of points, got {len(points)}')
    rc_points = [torch.tensor([p[1], p[0]]) for p in points]
    return rc_points[0::2], rc_points[1::2]


def _extract_features(dift, image_tensor, prompt, size):
    """Run the DIFT featurizer on one image and upsample to ``size`` (H, W)."""
    with torch.no_grad():
        ft = dift.forward(image_tensor,
                          prompt=prompt,
                          t=261,
                          up_ft_index=1,
                          ensemble_size=8)
        return F.interpolate(ft, size, mode='bilinear')


def _match_point(ft_source, ft_dragged, handle_point):
    """Return the (row, col) in ``ft_dragged`` most similar to the handle feature.

    ``ft_source`` / ``ft_dragged`` are indexed as (batch, channel, row, col);
    batch index 0 is used. ``handle_point`` is a (row, col) pair; non-integer
    coordinates are rounded to the nearest pixel for indexing.
    """
    row, col = (int(round(float(v))) for v in handle_point)
    src_vec = ft_source[0, :, row, col].view(1, -1, 1, 1)
    cos_map = F.cosine_similarity(src_vec, ft_dragged, dim=1).cpu().numpy()[0]
    max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape)
    return tuple(int(i) for i in max_rc)


def _evaluate_root(dift, target_root, original_img_root):
    """Return per-point distances (in pixels) for every sample of one result root."""
    # fixing the seed for semantic correspondence
    seed_everything(42)
    all_dist = []
    for cat in _ALL_CATEGORY:
        cat_dir = os.path.join(original_img_root, cat)
        # Sorted so that the processing order -- and therefore how the seeded
        # RNG stream is consumed per image -- does not depend on the arbitrary
        # order returned by os.listdir.
        for file_name in sorted(os.listdir(cat_dir)):
            sample_dir = os.path.join(cat_dir, file_name)
            # Skips '.DS_Store' and any other stray non-directory entry.
            if not os.path.isdir(sample_dir):
                continue
            # NOTE: pickle is only safe on trusted data (the benchmark's own files).
            with open(os.path.join(sample_dir, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            prompt = meta_data['prompt']
            handle_points, target_points = _split_points(meta_data['points'])

            source_image = _load_image_tensor(
                os.path.join(sample_dir, 'original_image.png'))
            _, H, W = source_image.shape
            dragged_image = _load_image_tensor(
                os.path.join(target_root, cat, file_name, 'dragged_image.png'),
                size=(W, H))

            ft_source = _extract_features(dift, source_image, prompt, (H, W))
            ft_dragged = _extract_features(dift, dragged_image, prompt, (H, W))

            for hp, tp in zip(handle_points, target_points):
                max_rc = _match_point(ft_source, ft_dragged, hp)
                dist = (tp.float() - torch.tensor(max_rc, dtype=torch.float32)).norm()
                all_dist.append(dist.item())
    return all_dist


def main():
    """Print the mean handle-to-target distance for every ``--eval_root``."""
    args = _parse_args()
    # using SD-2.1
    dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
    for target_root in args.eval_root:
        all_dist = _evaluate_root(dift, target_root, args.original_img_root)
        if not all_dist:
            print(target_root + ' mean distance: no point pairs found')
            continue
        print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())


if __name__ == '__main__':
    main()