diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..d41a94043c43d58b1c21d58399d4cc5eaf2a2606 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aa6c95de80d6ab950281ba9474ac52afb51c83a2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +annotator/ckpts/** +result/** +trash/** \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..dab0a2691396495842018a990a7a31cfe781058c --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 ST-Modulator authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/__pycache__/ptp_utils_null_text_inversion.cpython-310.pyc b/__pycache__/ptp_utils_null_text_inversion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72189c1cb33464c411420948eb04bc8ca21aa46f Binary files /dev/null and b/__pycache__/ptp_utils_null_text_inversion.cpython-310.pyc differ diff --git a/__pycache__/ptp_utils_null_text_inversion.cpython-38.pyc b/__pycache__/ptp_utils_null_text_inversion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dee2c11be89b6ce6fbbc5af9eb8a9754ad75a22 Binary files /dev/null and b/__pycache__/ptp_utils_null_text_inversion.cpython-38.pyc differ diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75bdc530d590d77443e15595e2729e9b963fb50 Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ diff --git a/__pycache__/xformers.cpython-310.pyc b/__pycache__/xformers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da513aa0483e57a23de04c9e727851218968252c Binary files /dev/null and b/__pycache__/xformers.cpython-310.pyc differ diff --git a/annotator/__pycache__/util.cpython-310.pyc b/annotator/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46828fe6843f9bdb563490855849e73236ca06e0 Binary files /dev/null and b/annotator/__pycache__/util.cpython-310.pyc differ diff --git a/annotator/__pycache__/util.cpython-39.pyc b/annotator/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d63a3a006ee95cec7b6f4ba1b7fb3216eb99b3b4 Binary files /dev/null and b/annotator/__pycache__/util.cpython-39.pyc differ diff --git a/annotator/canny/__init__.py b/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/annotator/canny/__pycache__/__init__.cpython-39.pyc b/annotator/canny/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64b98290015603ca5ac49f908537aa757d8a9db3 Binary files /dev/null and b/annotator/canny/__pycache__/__init__.cpython-39.pyc differ diff --git a/annotator/dwpose/__init__.py b/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91b7715b11894155744fc0679ebd5108716f6357 --- /dev/null +++ b/annotator/dwpose/__init__.py @@ -0,0 +1,71 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W, draw_body=True, draw_hand=True, draw_face=True): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + if draw_body: + canvas = util.draw_bodypose(canvas, candidate, subset) + + if draw_hand: + canvas = util.draw_handpose(canvas, hands) + + if draw_face: + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self): + + self.pose_estimation = Wholebody() + + def __call__(self, oriImg,hand=False, face=False): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W, draw_body=True, draw_hand=hand, draw_face=face) diff --git a/annotator/dwpose/__pycache__/__init__.cpython-310.pyc b/annotator/dwpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d656faa76f0f3d529f555d00b509aec35f9df47a Binary files /dev/null and b/annotator/dwpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/dwpose/__pycache__/__init__.cpython-38.pyc b/annotator/dwpose/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea73f6a0f27958140c0577695ca23322a42aa822 Binary files /dev/null and b/annotator/dwpose/__pycache__/__init__.cpython-38.pyc differ diff --git a/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc b/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a826180393fc91e2aa959ed6b9de45ab952c1059 Binary files /dev/null and b/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc differ diff --git a/annotator/dwpose/__pycache__/onnxdet.cpython-38.pyc b/annotator/dwpose/__pycache__/onnxdet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbde7cd417837996b1b6d43f20151c4c6ae7bed6 Binary files /dev/null and b/annotator/dwpose/__pycache__/onnxdet.cpython-38.pyc differ diff --git a/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc b/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38cdd0a2865ce8802c570028ef13bd0df70bd568 Binary files /dev/null and b/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc differ diff --git a/annotator/dwpose/__pycache__/onnxpose.cpython-38.pyc b/annotator/dwpose/__pycache__/onnxpose.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca9a6d8a508a48ed395e9fb78cfa2e25fce3695 Binary files /dev/null and b/annotator/dwpose/__pycache__/onnxpose.cpython-38.pyc differ diff --git a/annotator/dwpose/__pycache__/util.cpython-310.pyc b/annotator/dwpose/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e712199d91b577f54e8ffc92d2c5d42c85094e46 Binary files /dev/null and b/annotator/dwpose/__pycache__/util.cpython-310.pyc differ diff --git a/annotator/dwpose/__pycache__/util.cpython-38.pyc b/annotator/dwpose/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fabe9534aa89c5d0a1aa5f1a6ced77e0b9cd7fc7 Binary files /dev/null and b/annotator/dwpose/__pycache__/util.cpython-38.pyc differ diff --git a/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc b/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4faac40aa29e66e4fc6f5e5775db685fcd08438f Binary files /dev/null and b/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc differ diff --git a/annotator/dwpose/__pycache__/wholebody.cpython-38.pyc b/annotator/dwpose/__pycache__/wholebody.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..137746ba5a74fa4a21110155871f023654c9e838 Binary files /dev/null and b/annotator/dwpose/__pycache__/wholebody.cpython-38.pyc differ diff --git a/annotator/dwpose/dwpose_config/dwpose-l_384x288.py b/annotator/dwpose/dwpose_config/dwpose-l_384x288.py new file mode 100644 index 0000000000000000000000000000000000000000..c93a8ec09a21ab107abc623e2185d0b09f500df0 --- /dev/null +++ b/annotator/dwpose/dwpose_config/dwpose-l_384x288.py @@ -0,0 +1,257 @@ +# runtime +max_epochs = 270 +stage2_num_epochs = 30 +base_lr = 4e-3 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=21) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# codec settings +codec = dict( + type='SimCCLabel', + input_size=(288, 384), + sigma=(6., 6.93), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' + 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa + )), + head=dict( + type='RTMCCHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=(9, 12), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10., + label_softmax=True), + decoder=codec), + test_cfg=dict(flip_test=True, )) + +# base dataset settings +dataset_type = 'CocoWholeBodyDataset' +data_mode = 'topdown' +data_root = '/data/' + +backend_args = dict(backend='local') +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/', +# f'{data_root}': 's3://openmmlab/datasets/detection/coco/' +# })) + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +datasets = [] +dataset_coco=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='coco/train2017/'), + pipeline=[], +) +datasets.append(dataset_coco) + +scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', + 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', + 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] + +for i in range(len(scene)): + datasets.append( + dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', + data_prefix=dict(img='UBody/images/'+scene[i]+'/'), + pipeline=[], + ) + ) + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=datasets, + pipeline=train_pipeline, + test_mode=False, + )) +val_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='coco/annotations/coco_wholebody_val_v1.0.json', + bbox_file=f'{data_root}coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='coco/val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator \ No newline at end of file diff --git a/annotator/dwpose/onnxdet.py b/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/annotator/dwpose/onnxpose.py b/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/annotator/dwpose/util.py b/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/annotator/dwpose/wholebody.py b/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..a47c279e71db052ed3406ad2caa1b61cfddaee0d --- /dev/null +++ b/annotator/dwpose/wholebody.py @@ -0,0 +1,142 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from .onnxdet import inference_detector +from .onnxpose import inference_pose + +class Wholebody: + def __init__(self): + device = 'cuda:0' + providers = ['CPUExecutionProvider' + ] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = 'annotator/ckpts/yolox_l.onnx' + onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx' + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + + + +# # Copyright (c) OpenMMLab. All rights reserved. +# import numpy as np +# from . import util +# import cv2 +# import mmcv +# import torch +# import matplotlib.pyplot as plt +# from mmpose.apis import inference_topdown +# from mmpose.apis import init_model as init_pose_estimator +# from mmpose.evaluation.functional import nms +# from mmpose.utils import adapt_mmdet_pipeline +# from mmpose.structures import merge_data_samples + +# from mmdet.apis import inference_detector, init_detector + + +# class Wholebody: +# def __init__(self): +# device = 'cuda:0' +# det_config = 'annotator/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py' +# det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' +# pose_config = 'annotator/dwpose/dwpose_config/dwpose-l_384x288.py' +# pose_ckpt = 'annotator/ckpts/dw-ll_ucoco_384.pth' + +# # build detector +# self.detector = init_detector(det_config, det_ckpt, device=device) +# self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) + +# # build pose estimator +# self.pose_estimator = init_pose_estimator( +# pose_config, +# pose_ckpt, +# device=device) + +# def __call__(self, oriImg): +# # predict bbox +# det_result = inference_detector(self.detector, oriImg) +# pred_instance = det_result.pred_instances.cpu().numpy() +# bboxes = np.concatenate( +# (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) +# bboxes = bboxes[np.logical_and(pred_instance.labels == 0, +# pred_instance.scores > 0.3)] +# # # max value +# # if len(bboxes) > 0: +# # bboxes = bboxes[0].reshape(1,-1) +# bboxes = bboxes[nms(bboxes, 0.3), :4] + +# # predict keypoints +# if len(bboxes) == 0: +# pose_results = inference_topdown(self.pose_estimator, oriImg) +# else: +# pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) +# preds = merge_data_samples(pose_results) +# preds = preds.pred_instances + +# # preds = pose_results[0].pred_instances +# keypoints = preds.get('transformed_keypoints', +# preds.keypoints) +# if 'keypoint_scores' in preds: +# scores = preds.keypoint_scores +# else: +# scores = np.ones(keypoints.shape[:-1]) + +# if 'keypoints_visible' in preds: +# visible = preds.keypoints_visible +# else: +# visible = np.ones(keypoints.shape[:-1]) +# keypoints_info = np.concatenate( +# (keypoints, scores[..., None], visible[..., None]), +# axis=-1) +# # compute neck joint +# neck = np.mean(keypoints_info[:, [5, 6]], axis=1) +# # neck score when visualizing pred +# neck[:, 2:4] = np.logical_and( +# keypoints_info[:, 5, 2:4] > 0.3, +# keypoints_info[:, 6, 2:4] > 0.3).astype(int) +# new_keypoints_info = np.insert( +# keypoints_info, 17, neck, axis=1) +# mmpose_idx = [ +# 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 +# ] +# openpose_idx = [ +# 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 +# ] +# new_keypoints_info[:, openpose_idx] = \ +# new_keypoints_info[:, mmpose_idx] +# keypoints_info = new_keypoints_info + +# keypoints, scores, visible = keypoints_info[ +# ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + +# return keypoints, scores \ No newline at end of file diff --git a/annotator/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py b/annotator/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..7c26c1b1a2896e1b852903335ac13116bd6efae3 --- /dev/null +++ b/annotator/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py @@ -0,0 +1,245 @@ +img_scale = (640, 640) # width, height + +# model settings +model = dict( + type='YOLOX', + data_preprocessor=dict( + type='DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + backbone=dict( + type='CSPDarknet', + deepen_factor=1.0, + widen_factor=1.0, + out_indices=(2, 3, 4), + use_depthwise=False, + spp_kernal_sizes=(5, 9, 13), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + ), + neck=dict( + type='YOLOXPAFPN', + in_channels=[256, 512, 1024], + out_channels=256, + num_csp_blocks=3, + use_depthwise=False, + upsample_cfg=dict(scale_factor=2, mode='nearest'), + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish')), + bbox_head=dict( + type='YOLOXHead', + num_classes=80, + in_channels=256, + feat_channels=256, + stacked_convs=2, + strides=(8, 16, 32), + use_depthwise=False, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='Swish'), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_bbox=dict( + type='IoULoss', + mode='square', + eps=1e-16, + reduction='sum', + loss_weight=5.0), + loss_obj=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +# dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.1, 2), + # img_scale is (width, height) + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict( + type='MixUp', + img_scale=img_scale, + ratio_range=(0.8, 1.6), + pad_val=114.0), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + # According to the official implementation, multi-scale + # training is not considered here but in the + # 'mmdet/models/detectors/yolox.py'. + # Resize and Pad are for the last 15 epochs when Mosaic, + # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + # If the image is three-channel, the pad value needs + # to be set separately for each channel. + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='PackDetInputs') +] + +train_dataset = dict( + # use MultiImageMixDataset wrapper to support mosaic and mixup + type='MultiImageMixDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=[ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True) + ], + filter_cfg=dict(filter_empty_gt=False, min_size=32), + backend_args=backend_args), + pipeline=train_pipeline) + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=train_dataset) +val_dataloader = dict( + batch_size=8, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + backend_args=backend_args) +test_evaluator = val_evaluator + +# training settings +max_epochs = 300 +num_last_epochs = 15 +interval = 10 + +train_cfg = dict(max_epochs=max_epochs, val_interval=interval) + +# optimizer +# default 8 gpu +base_lr = 0.01 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, + nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) + +# learning rate +param_scheduler = [ + dict( + # use quadratic formula to warm up 5 epochs + # and lr is updated by iteration + # TODO: fix default scope in get function + type='mmdet.QuadraticWarmupLR', + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + # use cosine lr from 5 to 285 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True), + dict( + # use fixed lr during last 15 epochs + type='ConstantLR', + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ) +] + +default_hooks = dict( + checkpoint=dict( + interval=interval, + max_keep_ckpts=3 # only keep latest 3 checkpoints + )) + +custom_hooks = [ + dict( + type='YOLOXModeSwitchHook', + num_last_epochs=num_last_epochs, + priority=48), + dict(type='SyncNormHook', priority=48), + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0001, + update_buffers=True, + priority=49) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) \ No newline at end of file diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56532c374df5c26f9ec53e2ac0dd924f4534bbdd --- /dev/null +++ b/annotator/hed/__init__.py @@ -0,0 +1,132 @@ +import numpy as np +import cv2 +import os +import torch +from einops import rearrange +from annotator.util import annotator_ckpts_path + + +class Network(torch.nn.Module): + def __init__(self, model_path): + super().__init__() + + self.netVggOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.netVggTwo = torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.netVggThr = torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.netVggFou = torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.netVggFiv = torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0) + self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) + self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0) + self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) + self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) + + self.netCombine = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), + torch.nn.Sigmoid() + ) + + self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()}) + + def forward(self, tenInput): + tenInput = tenInput * 255.0 + tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1) + + tenVggOne = self.netVggOne(tenInput) + tenVggTwo = self.netVggTwo(tenVggOne) + tenVggThr = self.netVggThr(tenVggTwo) + tenVggFou = self.netVggFou(tenVggThr) + tenVggFiv = self.netVggFiv(tenVggFou) + + tenScoreOne = self.netScoreOne(tenVggOne) + tenScoreTwo = self.netScoreTwo(tenVggTwo) + tenScoreThr = self.netScoreThr(tenVggThr) + tenScoreFou = self.netScoreFou(tenVggFou) + tenScoreFiv = self.netScoreFiv(tenVggFiv) + + tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) + tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) + tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) + tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) + tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False) + + return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1)) + + +class HEDdetector: + def __init__(self): + remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth" + modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = Network(modelpath).cuda().eval() + + def __call__(self, input_image): + assert input_image.ndim == 3 + input_image = input_image[:, :, ::-1].copy() + with torch.no_grad(): + image_hed = torch.from_numpy(input_image).float().cuda() + image_hed = image_hed / 255.0 + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edge = self.netNetwork(image_hed)[0] + edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8) + return edge[0] + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/annotator/hed/__pycache__/__init__.cpython-39.pyc b/annotator/hed/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a8d2100832ddd05b48a63924e3cd1ece3d03f0e Binary files /dev/null and b/annotator/hed/__pycache__/__init__.cpython-39.pyc differ diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5ac03eea6f5ba7968706f1863c8bc4f8aaaf6a --- /dev/null +++ b/annotator/midas/__init__.py @@ -0,0 +1,38 @@ +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/annotator/midas/__pycache__/__init__.cpython-310.pyc b/annotator/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94d1e12de8722702cdab4272d186ca48a515d18c Binary files /dev/null and b/annotator/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/midas/__pycache__/api.cpython-310.pyc b/annotator/midas/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a24ca8ec7916cbfb1899f3020490d90a6038a180 Binary files /dev/null and b/annotator/midas/__pycache__/api.cpython-310.pyc differ diff --git a/annotator/midas/api.py b/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8 --- /dev/null +++ b/annotator/midas/api.py @@ -0,0 +1,169 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/annotator/midas/midas/__init__.py b/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc b/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45650853feeb5f28179cde97c9bb9b6fc0f646f7 Binary files /dev/null and b/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc b/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1223c99ae712592d5d21903147bb41fa995b0a0f Binary files /dev/null and b/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc b/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d068a3c5ff6e412fd6c83e3314bc47d27f602ea7 Binary files /dev/null and b/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc b/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5779c23707e01966d62af4e85283140c9e754c3 Binary files /dev/null and b/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc b/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd53826e275318803912c2c170b4181d52ff316d Binary files /dev/null and b/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc b/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..887e102d2f132dafb7105b090119cfd558bc8498 Binary files /dev/null and b/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc b/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c56361e4a9afcb21a661e8e248193d586eda05 Binary files /dev/null and b/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc differ diff --git a/annotator/midas/midas/__pycache__/vit.cpython-310.pyc b/annotator/midas/midas/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c127a0f38f4ff078a74ec10fb2713129ae14230d Binary files /dev/null and b/annotator/midas/midas/__pycache__/vit.cpython-310.pyc differ diff --git a/annotator/midas/midas/base_model.py b/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/annotator/midas/midas/blocks.py b/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/annotator/midas/midas/dpt_depth.py b/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/annotator/midas/midas/midas_net.py b/annotator/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/annotator/midas/midas/midas_net_custom.py b/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/annotator/midas/midas/transforms.py b/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/annotator/midas/midas/vit.py b/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/annotator/midas/utils.py b/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42af28c682e781b30f691f65a475b53c9f3adc8b --- /dev/null +++ b/annotator/mlsd/__init__.py @@ -0,0 +1,39 @@ +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from annotator.util import annotator_ckpts_path + + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth" + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3cf9420a33a4abae27c48ac4b90938c7d63cc3 --- /dev/null +++ b/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/annotator/openpose/LICENSE b/annotator/openpose/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd --- /dev/null +++ b/annotator/openpose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33fa30c06fb1554d579cbfc70d5c410836b4b6c8 --- /dev/null +++ b/annotator/openpose/__init__.py @@ -0,0 +1,100 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .body import Body +from .hand import Hand +from .face import Face +from annotator.util import annotator_ckpts_path + + +body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth" +hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth" +face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth" + + +def draw_pose(pose, H, W, draw_body=True, draw_hand=True, draw_face=True): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + if draw_body: + canvas = util.draw_bodypose(canvas, candidate, subset) + + if draw_hand: + canvas = util.draw_handpose(canvas, hands) + + if draw_face: + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class OpenposeDetector: + def __init__(self): + body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth") + hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth") + face_modelpath = os.path.join(annotator_ckpts_path, "facenet.pth") + + if not os.path.exists(body_modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(body_model_path, model_dir=annotator_ckpts_path) + + if not os.path.exists(hand_modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path) + + if not os.path.exists(face_modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(face_model_path, model_dir=annotator_ckpts_path) + + self.body_estimation = Body(body_modelpath) + self.hand_estimation = Hand(hand_modelpath) + self.face_estimation = Face(face_modelpath) + + def __call__(self, oriImg, hand_and_face=False, hand=False, face=False, return_is_index=False): + oriImg = oriImg[:, :, ::-1].copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.body_estimation(oriImg) + hands = [] + faces = [] + # if hand_and_face: + # Hand + hands_list = util.handDetect(candidate, subset, oriImg) + for x, y, w, is_left in hands_list: + peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + hands.append(peaks.tolist()) + # Face + faces_list = util.faceDetect(candidate, subset, oriImg) + for x, y, w in faces_list: + heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) + peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + faces.append(peaks.tolist()) + if candidate.ndim == 2 and candidate.shape[1] == 4: + candidate = candidate[:, :2] + candidate[:, 0] /= float(W) + candidate[:, 1] /= float(H) + bodies = dict(candidate=candidate.tolist(), subset=subset.tolist()) + pose = dict(bodies=bodies, hands=hands, faces=faces) + if return_is_index: + return pose + else: + return draw_pose(pose, H, W,draw_body=True, draw_hand=hand, draw_face=face) diff --git a/annotator/openpose/__pycache__/__init__.cpython-310.pyc b/annotator/openpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35bdd6c415fa270ec4465bb663962d5ff17bda49 Binary files /dev/null and b/annotator/openpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/openpose/__pycache__/body.cpython-310.pyc b/annotator/openpose/__pycache__/body.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1219368bb73be2b175e82a8fba271937f14013e Binary files /dev/null and b/annotator/openpose/__pycache__/body.cpython-310.pyc differ diff --git a/annotator/openpose/__pycache__/face.cpython-310.pyc b/annotator/openpose/__pycache__/face.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d0d2a1ad806d811dbb6cfd2a6a1b8aa5085489 Binary files /dev/null and b/annotator/openpose/__pycache__/face.cpython-310.pyc differ diff --git a/annotator/openpose/__pycache__/hand.cpython-310.pyc b/annotator/openpose/__pycache__/hand.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c37c858b9637f3b5b7aab5e773a7ca36a7d0f054 Binary files /dev/null and b/annotator/openpose/__pycache__/hand.cpython-310.pyc differ diff --git a/annotator/openpose/__pycache__/model.cpython-310.pyc b/annotator/openpose/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18b4424a74adb5e9e96e7a311603f9aed983a2f8 Binary files /dev/null and b/annotator/openpose/__pycache__/model.cpython-310.pyc differ diff --git a/annotator/openpose/__pycache__/util.cpython-310.pyc b/annotator/openpose/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fd671d024ccd9493e5f835ae5f877f6e3ebe8e0 Binary files /dev/null and b/annotator/openpose/__pycache__/util.cpython-310.pyc differ diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py new file mode 100644 index 0000000000000000000000000000000000000000..025783a90f2075ab90067dedf18b375ffb52b32e --- /dev/null +++ b/annotator/openpose/body.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from torchvision import transforms + +from . import util +from .model import bodypose_model + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + if torch.cuda.is_available(): + self.model = self.model.cuda() + print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = util.smart_resize_k(paf, fx=stride, fy=stride) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += + paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ + for I in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ + for I in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + +if __name__ == "__main__": + body_estimation = Body('../model/body_pose_model.pth') + + test_image = '../images/ski.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + candidate, subset = body_estimation(oriImg) + canvas = util.draw_bodypose(oriImg, candidate, subset) + plt.imshow(canvas[:, :, [2, 1, 0]]) + plt.show() diff --git a/annotator/openpose/face.py b/annotator/openpose/face.py new file mode 100644 index 0000000000000000000000000000000000000000..24969e620186a5183da0cb9fae6ca8db62fe57fa --- /dev/null +++ b/annotator/openpose/face.py @@ -0,0 +1,363 @@ +import logging +import numpy as np +from torchvision.transforms import ToTensor, ToPILImage +import torch +import torch.nn.functional as F +import cv2 + +from . import util +from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init + + +class FaceNet(Module): + """Model the cascading heatmaps. """ + def __init__(self): + super(FaceNet, self).__init__() + # cnn to make feature map + self.relu = ReLU() + self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, + kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1) + self.conv2_1 = Conv2d( + in_channels=64, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv2_2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv3_1 = Conv2d( + in_channels=128, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_2 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_3 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_4 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv4_1 = Conv2d( + in_channels=256, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_3 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_4 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_1 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_3_CPM = Conv2d( + in_channels=512, out_channels=128, kernel_size=3, stride=1, + padding=1) + + # stage1 + self.conv6_1_CPM = Conv2d( + in_channels=128, out_channels=512, kernel_size=1, stride=1, + padding=0) + self.conv6_2_CPM = Conv2d( + in_channels=512, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage2 + self.Mconv1_stage2 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage2 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage3 + self.Mconv1_stage3 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage3 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage4 + self.Mconv1_stage4 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage4 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage5 + self.Mconv1_stage5 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage5 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage6 + self.Mconv1_stage6 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage6 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + for m in self.modules(): + if isinstance(m, Conv2d): + init.constant_(m.bias, 0) + + def forward(self, x): + """Return a list of heatmaps.""" + heatmaps = [] + + h = self.relu(self.conv1_1(x)) + h = self.relu(self.conv1_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv2_1(h)) + h = self.relu(self.conv2_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv3_1(h)) + h = self.relu(self.conv3_2(h)) + h = self.relu(self.conv3_3(h)) + h = self.relu(self.conv3_4(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv4_1(h)) + h = self.relu(self.conv4_2(h)) + h = self.relu(self.conv4_3(h)) + h = self.relu(self.conv4_4(h)) + h = self.relu(self.conv5_1(h)) + h = self.relu(self.conv5_2(h)) + h = self.relu(self.conv5_3_CPM(h)) + feature_map = h + + # stage1 + h = self.relu(self.conv6_1_CPM(h)) + h = self.conv6_2_CPM(h) + heatmaps.append(h) + + # stage2 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage2(h)) + h = self.relu(self.Mconv2_stage2(h)) + h = self.relu(self.Mconv3_stage2(h)) + h = self.relu(self.Mconv4_stage2(h)) + h = self.relu(self.Mconv5_stage2(h)) + h = self.relu(self.Mconv6_stage2(h)) + h = self.Mconv7_stage2(h) + heatmaps.append(h) + + # stage3 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage3(h)) + h = self.relu(self.Mconv2_stage3(h)) + h = self.relu(self.Mconv3_stage3(h)) + h = self.relu(self.Mconv4_stage3(h)) + h = self.relu(self.Mconv5_stage3(h)) + h = self.relu(self.Mconv6_stage3(h)) + h = self.Mconv7_stage3(h) + heatmaps.append(h) + + # stage4 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage4(h)) + h = self.relu(self.Mconv2_stage4(h)) + h = self.relu(self.Mconv3_stage4(h)) + h = self.relu(self.Mconv4_stage4(h)) + h = self.relu(self.Mconv5_stage4(h)) + h = self.relu(self.Mconv6_stage4(h)) + h = self.Mconv7_stage4(h) + heatmaps.append(h) + + # stage5 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage5(h)) + h = self.relu(self.Mconv2_stage5(h)) + h = self.relu(self.Mconv3_stage5(h)) + h = self.relu(self.Mconv4_stage5(h)) + h = self.relu(self.Mconv5_stage5(h)) + h = self.relu(self.Mconv6_stage5(h)) + h = self.Mconv7_stage5(h) + heatmaps.append(h) + + # stage6 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage6(h)) + h = self.relu(self.Mconv2_stage6(h)) + h = self.relu(self.Mconv3_stage6(h)) + h = self.relu(self.Mconv4_stage6(h)) + h = self.relu(self.Mconv5_stage6(h)) + h = self.relu(self.Mconv6_stage6(h)) + h = self.Mconv7_stage6(h) + heatmaps.append(h) + + return heatmaps + + +LOG = logging.getLogger(__name__) +TOTEN = ToTensor() +TOPIL = ToPILImage() + + +params = { + 'gaussian_sigma': 2.5, + 'inference_img_size': 736, # 368, 736, 1312 + 'heatmap_peak_thresh': 0.1, + 'crop_scale': 1.5, + 'line_indices': [ + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], + [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], + [13, 14], [14, 15], [15, 16], + [17, 18], [18, 19], [19, 20], [20, 21], + [22, 23], [23, 24], [24, 25], [25, 26], + [27, 28], [28, 29], [29, 30], + [31, 32], [32, 33], [33, 34], [34, 35], + [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], + [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], + [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], + [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], + [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], + [66, 67], [67, 60] + ], +} + + +class Face(object): + """ + The OpenPose face landmark detector model. + + Args: + inference_size: set the size of the inference image size, suggested: + 368, 736, 1312, default 736 + gaussian_sigma: blur the heatmaps, default 2.5 + heatmap_peak_thresh: return landmark if over threshold, default 0.1 + + """ + def __init__(self, face_model_path, + inference_size=None, + gaussian_sigma=None, + heatmap_peak_thresh=None): + self.inference_size = inference_size or params["inference_img_size"] + self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] + self.model = FaceNet() + self.model.load_state_dict(torch.load(face_model_path)) + if torch.cuda.is_available(): + self.model = self.model.cuda() + print('cuda') + self.model.eval() + + def __call__(self, face_img): + H, W, C = face_img.shape + + w_size = 384 + x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5 + + if torch.cuda.is_available(): + x_data = x_data.cuda() + + with torch.no_grad(): + hs = self.model(x_data[None, ...]) + heatmaps = F.interpolate( + hs[-1], + (H, W), + mode='bilinear', align_corners=True).cpu().numpy()[0] + return heatmaps + + def compute_peaks_from_heatmaps(self, heatmaps): + all_peaks = [] + for part in range(heatmaps.shape[0]): + map_ori = heatmaps[part].copy() + binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8) + + if np.sum(binary) == 0: + continue + + positions = np.where(binary > 0.5) + intensities = map_ori[positions] + mi = np.argmax(intensities) + y, x = positions[0][mi], positions[1][mi] + all_peaks.append([x, y]) + + return np.array(all_peaks) diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py new file mode 100644 index 0000000000000000000000000000000000000000..0b268d3c9e3077b6155df01fe7b7e5aae377dbe4 --- /dev/null +++ b/annotator/openpose/hand.py @@ -0,0 +1,93 @@ +import cv2 +import json +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from skimage.measure import label + +from .model import handpose_model +from . import util + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + if torch.cuda.is_available(): + self.model = self.model.cuda() + print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImgRaw): + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + + with torch.no_grad(): + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + y = int(float(y) * float(Hr) / float(wsize)) + x = int(float(x) * float(Wr) / float(wsize)) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) \ No newline at end of file diff --git a/annotator/openpose/model.py b/annotator/openpose/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de --- /dev/null +++ b/annotator/openpose/model.py @@ -0,0 +1,219 @@ +import torch +from collections import OrderedDict + +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 + + diff --git a/annotator/openpose/util.py b/annotator/openpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/annotator/openpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6be429542e4908c2b7648e7ee7c9c5f8253e7c94 --- /dev/null +++ b/annotator/uniformer/__init__.py @@ -0,0 +1,23 @@ +import os + +from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot +from annotator.uniformer.mmseg.core.evaluation import get_palette +from annotator.util import annotator_ckpts_path + + +checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth" + + +class UniformerDetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path) + config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py") + self.model = init_segmentor(config_file, modelpath).cuda() + + def __call__(self, img): + result = inference_segmentor(self.model, img) + res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1) + return res_img diff --git a/annotator/uniformer/configs/_base_/datasets/ade20k.py b/annotator/uniformer/configs/_base_/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/ade20k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/chase_db1.py b/annotator/uniformer/configs/_base_/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/chase_db1.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'ChaseDB1Dataset' +data_root = 'data/CHASE_DB1' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (960, 999) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes.py b/annotator/uniformer/configs/_base_/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/train', + ann_dir='gtFine/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py new file mode 100644 index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (769, 769) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2049, 1025), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/drive.py b/annotator/uniformer/configs/_base_/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/drive.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'DRIVEDataset' +data_root = 'data/DRIVE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (584, 565) +crop_size = (64, 64) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/hrf.py b/annotator/uniformer/configs/_base_/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/hrf.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'HRFDataset' +data_root = 'data/HRF' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (2336, 3504) +crop_size = (256, 256) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context.py b/annotator/uniformer/configs/_base_/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_context.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py new file mode 100644 index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset59' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555 --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'PascalVOCDataset' +data_root = 'data/VOCdevkit/VOC2012' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py @@ -0,0 +1,9 @@ +_base_ = './pascal_voc12.py' +# dataset settings +data = dict( + train=dict( + ann_dir=['SegmentationClass', 'SegmentationClassAug'], + split=[ + 'ImageSets/Segmentation/train.txt', + 'ImageSets/Segmentation/aug.txt' + ])) diff --git a/annotator/uniformer/configs/_base_/datasets/stare.py b/annotator/uniformer/configs/_base_/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e --- /dev/null +++ b/annotator/uniformer/configs/_base_/datasets/stare.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'STAREDataset' +data_root = 'data/STARE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (605, 700) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/annotator/uniformer/configs/_base_/default_runtime.py b/annotator/uniformer/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b --- /dev/null +++ b/annotator/uniformer/configs/_base_/default_runtime.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_160k.py b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py new file mode 100644 index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811 --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=16000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_20k.py b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py new file mode 100644 index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=20000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_40k.py b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=40000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_80k.py b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617 --- /dev/null +++ b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=8000) +evaluation = dict(interval=8000, metric='mIoU') diff --git a/annotator/uniformer/mmcv/__init__.py b/annotator/uniformer/mmcv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03 --- /dev/null +++ b/annotator/uniformer/mmcv/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# flake8: noqa +from .arraymisc import * +from .fileio import * +from .image import * +from .utils import * +from .version import * +from .video import * +from .visualization import * + +# The following modules are not imported to this level, so mmcv may be used +# without PyTorch. +# - runner +# - parallel +# - op diff --git a/annotator/uniformer/mmcv/arraymisc/__init__.py b/annotator/uniformer/mmcv/arraymisc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c --- /dev/null +++ b/annotator/uniformer/mmcv/arraymisc/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .quantization import dequantize, quantize + +__all__ = ['quantize', 'dequantize'] diff --git a/annotator/uniformer/mmcv/arraymisc/quantization.py b/annotator/uniformer/mmcv/arraymisc/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186 --- /dev/null +++ b/annotator/uniformer/mmcv/arraymisc/quantization.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum( + np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/annotator/uniformer/mmcv/cnn/__init__.py b/annotator/uniformer/mmcv/cnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .alexnet import AlexNet +# yapf: disable +from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS, + ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule, + ConvTranspose2d, ConvTranspose3d, ConvWS2d, + DepthwiseSeparableConvModule, GeneralizedAttention, + HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d, + NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish, + build_activation_layer, build_conv_layer, + build_norm_layer, build_padding_layer, build_plugin_layer, + build_upsample_layer, conv_ws_2d, is_norm) +from .builder import MODELS, build_model_from_cfg +# yapf: enable +from .resnet import ResNet, make_res_layer +from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit, + NormalInit, PretrainedInit, TruncNormalInit, UniformInit, + XavierInit, bias_init_with_prob, caffe2_xavier_init, + constant_init, fuse_conv_bn, get_model_complexity_info, + initialize, kaiming_init, normal_init, trunc_normal_init, + uniform_init, xavier_init) +from .vgg import VGG, make_vgg_layer + +__all__ = [ + 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer', + 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init', + 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', + 'bias_init_with_prob', 'ConvModule', 'build_activation_layer', + 'build_conv_layer', 'build_norm_layer', 'build_padding_layer', + 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d', + 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish', + 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', + 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', + 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d', + 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', + 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', + 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', + 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg' +] diff --git a/annotator/uniformer/mmcv/cnn/alexnet.py b/annotator/uniformer/mmcv/cnn/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/alexnet.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn + + +class AlexNet(nn.Module): + """AlexNet backbone. + + Args: + num_classes (int): number of classes for classification. + """ + + def __init__(self, num_classes=-1): + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + # use default initializer + pass + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + + x = self.features(x) + if self.num_classes > 0: + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/__init__.py b/annotator/uniformer/mmcv/cnn/bricks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .activation import build_activation_layer +from .context_block import ContextBlock +from .conv import build_conv_layer +from .conv2d_adaptive_padding import Conv2dAdaptivePadding +from .conv_module import ConvModule +from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d +from .depthwise_separable_conv_module import DepthwiseSeparableConvModule +from .drop import Dropout, DropPath +from .generalized_attention import GeneralizedAttention +from .hsigmoid import HSigmoid +from .hswish import HSwish +from .non_local import NonLocal1d, NonLocal2d, NonLocal3d +from .norm import build_norm_layer, is_norm +from .padding import build_padding_layer +from .plugin import build_plugin_layer +from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, + PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS) +from .scale import Scale +from .swish import Swish +from .upsample import build_upsample_layer +from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, + Linear, MaxPool2d, MaxPool3d) + +__all__ = [ + 'ConvModule', 'build_activation_layer', 'build_conv_layer', + 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer', + 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d', + 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention', + 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', + 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d', + 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear', + 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', + 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath' +] diff --git a/annotator/uniformer/mmcv/cnn/bricks/activation.py b/annotator/uniformer/mmcv/cnn/bricks/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..cab2712287d5ef7be2f079dcb54a94b96394eab5 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/activation.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version +from .registry import ACTIVATION_LAYERS + +for module in [ + nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU, + nn.Sigmoid, nn.Tanh +]: + ACTIVATION_LAYERS.register_module(module=module) + + +@ACTIVATION_LAYERS.register_module(name='Clip') +@ACTIVATION_LAYERS.register_module() +class Clamp(nn.Module): + """Clamp activation layer. + + This activation function is to clamp the feature map value within + :math:`[min, max]`. More details can be found in ``torch.clamp()``. + + Args: + min (Number | optional): Lower-bound of the range to be clamped to. + Default to -1. + max (Number | optional): Upper-bound of the range to be clamped to. + Default to 1. + """ + + def __init__(self, min=-1., max=1.): + super(Clamp, self).__init__() + self.min = min + self.max = max + + def forward(self, x): + """Forward function. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: Clamped tensor. + """ + return torch.clamp(x, min=self.min, max=self.max) + + +class GELU(nn.Module): + r"""Applies the Gaussian Error Linear Units function: + + .. math:: + \text{GELU}(x) = x * \Phi(x) + where :math:`\Phi(x)` is the Cumulative Distribution Function for + Gaussian Distribution. + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: scripts/activation_images/GELU.png + + Examples:: + + >>> m = nn.GELU() + >>> input = torch.randn(2) + >>> output = m(input) + """ + + def forward(self, input): + return F.gelu(input) + + +if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.4')): + ACTIVATION_LAYERS.register_module(module=GELU) +else: + ACTIVATION_LAYERS.register_module(module=nn.GELU) + + +def build_activation_layer(cfg): + """Build activation layer. + + Args: + cfg (dict): The activation layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an activation layer. + + Returns: + nn.Module: Created activation layer. + """ + return build_from_cfg(cfg, ACTIVATION_LAYERS) diff --git a/annotator/uniformer/mmcv/cnn/bricks/context_block.py b/annotator/uniformer/mmcv/cnn/bricks/context_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/context_block.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..utils import constant_init, kaiming_init +from .registry import PLUGIN_LAYERS + + +def last_zero_init(m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + +@PLUGIN_LAYERS.register_module() +class ContextBlock(nn.Module): + """ContextBlock module in GCNet. + + See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond' + (https://arxiv.org/abs/1904.11492) for details. + + Args: + in_channels (int): Channels of the input feature map. + ratio (float): Ratio of channels of transform bottleneck + pooling_type (str): Pooling method for context modeling. + Options are 'att' and 'avg', stand for attention pooling and + average pooling respectively. Default: 'att'. + fusion_types (Sequence[str]): Fusion method for feature fusion, + Options are 'channels_add', 'channel_mul', stand for channelwise + addition and multiplication respectively. Default: ('channel_add',) + """ + + _abbr_ = 'context_block' + + def __init__(self, + in_channels, + ratio, + pooling_type='att', + fusion_types=('channel_add', )): + super(ContextBlock, self).__init__() + assert pooling_type in ['avg', 'att'] + assert isinstance(fusion_types, (list, tuple)) + valid_fusion_types = ['channel_add', 'channel_mul'] + assert all([f in valid_fusion_types for f in fusion_types]) + assert len(fusion_types) > 0, 'at least one fusion should be used' + self.in_channels = in_channels + self.ratio = ratio + self.planes = int(in_channels * ratio) + self.pooling_type = pooling_type + self.fusion_types = fusion_types + if pooling_type == 'att': + self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1) + self.softmax = nn.Softmax(dim=2) + else: + self.avg_pool = nn.AdaptiveAvgPool2d(1) + if 'channel_add' in fusion_types: + self.channel_add_conv = nn.Sequential( + nn.Conv2d(self.in_channels, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), # yapf: disable + nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) + else: + self.channel_add_conv = None + if 'channel_mul' in fusion_types: + self.channel_mul_conv = nn.Sequential( + nn.Conv2d(self.in_channels, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), # yapf: disable + nn.Conv2d(self.planes, self.in_channels, kernel_size=1)) + else: + self.channel_mul_conv = None + self.reset_parameters() + + def reset_parameters(self): + if self.pooling_type == 'att': + kaiming_init(self.conv_mask, mode='fan_in') + self.conv_mask.inited = True + + if self.channel_add_conv is not None: + last_zero_init(self.channel_add_conv) + if self.channel_mul_conv is not None: + last_zero_init(self.channel_mul_conv) + + def spatial_pool(self, x): + batch, channel, height, width = x.size() + if self.pooling_type == 'att': + input_x = x + # [N, C, H * W] + input_x = input_x.view(batch, channel, height * width) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + context_mask = self.conv_mask(x) + # [N, 1, H * W] + context_mask = context_mask.view(batch, 1, height * width) + # [N, 1, H * W] + context_mask = self.softmax(context_mask) + # [N, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N, 1, C, 1] + context = torch.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = context.view(batch, channel, 1, 1) + else: + # [N, C, 1, 1] + context = self.avg_pool(x) + + return context + + def forward(self, x): + # [N, C, 1, 1] + context = self.spatial_pool(x) + + out = x + if self.channel_mul_conv is not None: + # [N, C, 1, 1] + channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) + out = out * channel_mul_term + if self.channel_add_conv is not None: + # [N, C, 1, 1] + channel_add_term = self.channel_add_conv(context) + out = out + channel_add_term + + return out diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv.py b/annotator/uniformer/mmcv/cnn/bricks/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from .registry import CONV_LAYERS + +CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d) +CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d) +CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d) +CONV_LAYERS.register_module('Conv', module=nn.Conv2d) + + +def build_conv_layer(cfg, *args, **kwargs): + """Build convolution layer. + + Args: + cfg (None or dict): The conv layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate an conv layer. + args (argument list): Arguments passed to the `__init__` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the `__init__` + method of the corresponding conv layer. + + Returns: + nn.Module: Created conv layer. + """ + if cfg is None: + cfg_ = dict(type='Conv2d') + else: + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in CONV_LAYERS: + raise KeyError(f'Unrecognized norm type {layer_type}') + else: + conv_layer = CONV_LAYERS.get(layer_type) + + layer = conv_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py new file mode 100644 index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from torch import nn +from torch.nn import functional as F + +from .registry import CONV_LAYERS + + +@CONV_LAYERS.register_module() +class Conv2dAdaptivePadding(nn.Conv2d): + """Implementation of 2D convolution in tensorflow with `padding` as "same", + which applies padding to input (if needed) so that input image gets fully + covered by filter and stride you specified. For stride 1, this will ensure + that output image size is same as input. For stride of 2, output dimensions + will be half, for example. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + + def forward(self, x): + img_h, img_w = x.size()[-2:] + kernel_h, kernel_w = self.weight.size()[-2:] + stride_h, stride_w = self.stride + output_h = math.ceil(img_h / stride_h) + output_w = math.ceil(img_w / stride_w) + pad_h = ( + max((output_h - 1) * self.stride[0] + + (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0)) + pad_w = ( + max((output_w - 1) * self.stride[1] + + (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0)) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py new file mode 100644 index 0000000000000000000000000000000000000000..e60e7e62245071c77b652093fddebff3948d7c3e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm +from ..utils import constant_init, kaiming_init +from .activation import build_activation_layer +from .conv import build_conv_layer +from .norm import build_norm_layer +from .padding import build_padding_layer +from .registry import PLUGIN_LAYERS + + +@PLUGIN_LAYERS.register_module() +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = 'conv_block' + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=True, + with_spectral_norm=False, + padding_mode='zeros', + order=('conv', 'norm', 'act')): + super(ConvModule, self).__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert act_cfg is None or isinstance(act_cfg, dict) + official_padding_mode = ['zeros', 'circular'] + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(['conv', 'norm', 'act']) + + self.with_norm = norm_cfg is not None + self.with_activation = act_cfg is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == 'auto': + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + pad_cfg = dict(type=padding_mode) + self.padding_layer = build_padding_layer(pad_cfg, padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = build_conv_layer( + conv_cfg, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + if self.with_bias: + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn( + 'Unnecessary conv bias before batch/instance norm') + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + act_cfg_ = act_cfg.copy() + # nn.Tanh has no 'inplace' argument + if act_cfg_['type'] not in [ + 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish' + ]: + act_cfg_.setdefault('inplace', inplace) + self.activate = build_activation_layer(act_cfg_) + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, 'init_weights'): + if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': + nonlinearity = 'leaky_relu' + a = self.act_cfg.get('negative_slope', 0.01) + else: + nonlinearity = 'relu' + a = 0 + kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) + if self.with_norm: + constant_init(self.norm, 1, bias=0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and activate and self.with_activation: + x = self.activate(x) + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import CONV_LAYERS + + +def conv_ws_2d(input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + eps=1e-5): + c_in = weight.size(0) + weight_flat = weight.view(c_in, -1) + mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1) + std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1) + weight = (weight - mean) / (std + eps) + return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + + +@CONV_LAYERS.register_module('ConvWS') +class ConvWS2d(nn.Conv2d): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + eps=1e-5): + super(ConvWS2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.eps = eps + + def forward(self, x): + return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.eps) + + +@CONV_LAYERS.register_module(name='ConvAWS') +class ConvAWS2d(nn.Conv2d): + """AWS (Adaptive Weight Standardization) + + This is a variant of Weight Standardization + (https://arxiv.org/pdf/1903.10520.pdf) + It is used in DetectoRS to avoid NaN + (https://arxiv.org/pdf/2006.02334.pdf) + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the conv kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If set True, adds a learnable bias to the + output. Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.register_buffer('weight_gamma', + torch.ones(self.out_channels, 1, 1, 1)) + self.register_buffer('weight_beta', + torch.zeros(self.out_channels, 1, 1, 1)) + + def _get_weight(self, weight): + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + weight = (weight - mean) / std + weight = self.weight_gamma * weight + self.weight_beta + return weight + + def forward(self, x): + weight = self._get_weight(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Override default load function. + + AWS overrides the function _load_from_state_dict to recover + weight_gamma and weight_beta if they are missing. If weight_gamma and + weight_beta are found in the checkpoint, this function will return + after super()._load_from_state_dict. Otherwise, it will compute the + mean and std of the pretrained weights and store them in weight_beta + and weight_gamma. + """ + + self.weight_gamma.data.fill_(-1) + local_missing_keys = [] + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, local_missing_keys, + unexpected_keys, error_msgs) + if self.weight_gamma.data.mean() > 0: + for k in local_missing_keys: + missing_keys.append(k) + return + weight = self.weight.data + weight_flat = weight.view(weight.size(0), -1) + mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1) + std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1) + self.weight_beta.data.copy_(mean) + self.weight_gamma.data.copy_(std) + missing_gamma_beta = [ + k for k in local_missing_keys + if k.endswith('weight_gamma') or k.endswith('weight_beta') + ] + for k in missing_gamma_beta: + local_missing_keys.remove(k) + for k in local_missing_keys: + missing_keys.append(k) diff --git a/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py new file mode 100644 index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .conv_module import ConvModule + + +class DepthwiseSeparableConvModule(nn.Module): + """Depthwise separable convolution module. + + See https://arxiv.org/pdf/1704.04861.pdf for details. + + This module can replace a ConvModule with the conv block replaced by two + conv block: depthwise conv block and pointwise conv block. The depthwise + conv block contains depthwise-conv/norm/activation layers. The pointwise + conv block contains pointwise-conv/norm/activation layers. It should be + noted that there will be norm/activation layer in the depthwise conv block + if `norm_cfg` and `act_cfg` are specified. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. Default: 1. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. Default: 0. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. Default: 1. + norm_cfg (dict): Default norm config for both depthwise ConvModule and + pointwise ConvModule. Default: None. + act_cfg (dict): Default activation config for both depthwise ConvModule + and pointwise ConvModule. Default: dict(type='ReLU'). + dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is + 'default', it will be the same as `norm_cfg`. Default: 'default'. + pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: 'default'. + kwargs (optional): Other shared arguments for depthwise and pointwise + ConvModule. See ConvModule for ref. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + dw_norm_cfg='default', + dw_act_cfg='default', + pw_norm_cfg='default', + pw_act_cfg='default', + **kwargs): + super(DepthwiseSeparableConvModule, self).__init__() + assert 'groups' not in kwargs, 'groups should not be specified' + + # if norm/activation config of depthwise/pointwise ConvModule is not + # specified, use default config. + dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg + dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg + pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg + pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg + + # depthwise convolution + self.depthwise_conv = ConvModule( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + norm_cfg=dw_norm_cfg, + act_cfg=dw_act_cfg, + **kwargs) + + self.pointwise_conv = ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=pw_norm_cfg, + act_cfg=pw_act_cfg, + **kwargs) + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x diff --git a/annotator/uniformer/mmcv/cnn/bricks/drop.py b/annotator/uniformer/mmcv/cnn/bricks/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b4fccd457a0d51fb10c789df3c8537fe7b67c1 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/drop.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv import build_from_cfg +from .registry import DROPOUT_LAYERS + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + # handle tensors with different dimensions, not just 4D tensors. + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + output = x.div(keep_prob) * random_tensor.floor() + return output + + +@DROPOUT_LAYERS.register_module() +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + + We follow the implementation + https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 + + Args: + drop_prob (float): Probability of the path to be zeroed. Default: 0.1 + """ + + def __init__(self, drop_prob=0.1): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +@DROPOUT_LAYERS.register_module() +class Dropout(nn.Dropout): + """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of + ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with + ``DropPath`` + + Args: + drop_prob (float): Probability of the elements to be + zeroed. Default: 0.5. + inplace (bool): Do the operation inplace or not. Default: False. + """ + + def __init__(self, drop_prob=0.5, inplace=False): + super().__init__(p=drop_prob, inplace=inplace) + + +def build_dropout(cfg, default_args=None): + """Builder for drop out layers.""" + return build_from_cfg(cfg, DROPOUT_LAYERS, default_args) diff --git a/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py @@ -0,0 +1,412 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import kaiming_init +from .registry import PLUGIN_LAYERS + + +@PLUGIN_LAYERS.register_module() +class GeneralizedAttention(nn.Module): + """GeneralizedAttention module. + + See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks' + (https://arxiv.org/abs/1711.07971) for details. + + Args: + in_channels (int): Channels of the input feature map. + spatial_range (int): The spatial range. -1 indicates no spatial range + constraint. Default: -1. + num_heads (int): The head number of empirical_attention module. + Default: 9. + position_embedding_dim (int): The position embedding dimension. + Default: -1. + position_magnitude (int): A multiplier acting on coord difference. + Default: 1. + kv_stride (int): The feature stride acting on key/value feature map. + Default: 2. + q_stride (int): The feature stride acting on query feature map. + Default: 1. + attention_type (str): A binary indicator string for indicating which + items in generalized empirical_attention module are used. + Default: '1111'. + + - '1000' indicates 'query and key content' (appr - appr) item, + - '0100' indicates 'query content and relative position' + (appr - position) item, + - '0010' indicates 'key content only' (bias - appr) item, + - '0001' indicates 'relative position only' (bias - position) item. + """ + + _abbr_ = 'gen_attention_block' + + def __init__(self, + in_channels, + spatial_range=-1, + num_heads=9, + position_embedding_dim=-1, + position_magnitude=1, + kv_stride=2, + q_stride=1, + attention_type='1111'): + + super(GeneralizedAttention, self).__init__() + + # hard range means local range for non-local operation + self.position_embedding_dim = ( + position_embedding_dim + if position_embedding_dim > 0 else in_channels) + + self.position_magnitude = position_magnitude + self.num_heads = num_heads + self.in_channels = in_channels + self.spatial_range = spatial_range + self.kv_stride = kv_stride + self.q_stride = q_stride + self.attention_type = [bool(int(_)) for _ in attention_type] + self.qk_embed_dim = in_channels // num_heads + out_c = self.qk_embed_dim * num_heads + + if self.attention_type[0] or self.attention_type[1]: + self.query_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_c, + kernel_size=1, + bias=False) + self.query_conv.kaiming_init = True + + if self.attention_type[0] or self.attention_type[2]: + self.key_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_c, + kernel_size=1, + bias=False) + self.key_conv.kaiming_init = True + + self.v_dim = in_channels // num_heads + self.value_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=self.v_dim * num_heads, + kernel_size=1, + bias=False) + self.value_conv.kaiming_init = True + + if self.attention_type[1] or self.attention_type[3]: + self.appr_geom_fc_x = nn.Linear( + self.position_embedding_dim // 2, out_c, bias=False) + self.appr_geom_fc_x.kaiming_init = True + + self.appr_geom_fc_y = nn.Linear( + self.position_embedding_dim // 2, out_c, bias=False) + self.appr_geom_fc_y.kaiming_init = True + + if self.attention_type[2]: + stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2) + appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv + self.appr_bias = nn.Parameter(appr_bias_value) + + if self.attention_type[3]: + stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2) + geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv + self.geom_bias = nn.Parameter(geom_bias_value) + + self.proj_conv = nn.Conv2d( + in_channels=self.v_dim * num_heads, + out_channels=in_channels, + kernel_size=1, + bias=True) + self.proj_conv.kaiming_init = True + self.gamma = nn.Parameter(torch.zeros(1)) + + if self.spatial_range >= 0: + # only works when non local is after 3*3 conv + if in_channels == 256: + max_len = 84 + elif in_channels == 512: + max_len = 42 + + max_len_kv = int((max_len - 1.0) / self.kv_stride + 1) + local_constraint_map = np.ones( + (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int) + for iy in range(max_len): + for ix in range(max_len): + local_constraint_map[ + iy, ix, + max((iy - self.spatial_range) // + self.kv_stride, 0):min((iy + self.spatial_range + + 1) // self.kv_stride + + 1, max_len), + max((ix - self.spatial_range) // + self.kv_stride, 0):min((ix + self.spatial_range + + 1) // self.kv_stride + + 1, max_len)] = 0 + + self.local_constraint_map = nn.Parameter( + torch.from_numpy(local_constraint_map).byte(), + requires_grad=False) + + if self.q_stride > 1: + self.q_downsample = nn.AvgPool2d( + kernel_size=1, stride=self.q_stride) + else: + self.q_downsample = None + + if self.kv_stride > 1: + self.kv_downsample = nn.AvgPool2d( + kernel_size=1, stride=self.kv_stride) + else: + self.kv_downsample = None + + self.init_weights() + + def get_position_embedding(self, + h, + w, + h_kv, + w_kv, + q_stride, + kv_stride, + device, + dtype, + feat_dim, + wave_length=1000): + # the default type of Tensor is float32, leading to type mismatch + # in fp16 mode. Cast it to support fp16 mode. + h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype) + h_idxs = h_idxs.view((h, 1)) * q_stride + + w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype) + w_idxs = w_idxs.view((w, 1)) * q_stride + + h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to( + device=device, dtype=dtype) + h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride + + w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to( + device=device, dtype=dtype) + w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride + + # (h, h_kv, 1) + h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0) + h_diff *= self.position_magnitude + + # (w, w_kv, 1) + w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0) + w_diff *= self.position_magnitude + + feat_range = torch.arange(0, feat_dim / 4).to( + device=device, dtype=dtype) + + dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype) + dim_mat = dim_mat**((4. / feat_dim) * feat_range) + dim_mat = dim_mat.view((1, 1, -1)) + + embedding_x = torch.cat( + ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2) + + embedding_y = torch.cat( + ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2) + + return embedding_x, embedding_y + + def forward(self, x_input): + num_heads = self.num_heads + + # use empirical_attention + if self.q_downsample is not None: + x_q = self.q_downsample(x_input) + else: + x_q = x_input + n, _, h, w = x_q.shape + + if self.kv_downsample is not None: + x_kv = self.kv_downsample(x_input) + else: + x_kv = x_input + _, _, h_kv, w_kv = x_kv.shape + + if self.attention_type[0] or self.attention_type[1]: + proj_query = self.query_conv(x_q).view( + (n, num_heads, self.qk_embed_dim, h * w)) + proj_query = proj_query.permute(0, 1, 3, 2) + + if self.attention_type[0] or self.attention_type[2]: + proj_key = self.key_conv(x_kv).view( + (n, num_heads, self.qk_embed_dim, h_kv * w_kv)) + + if self.attention_type[1] or self.attention_type[3]: + position_embed_x, position_embed_y = self.get_position_embedding( + h, w, h_kv, w_kv, self.q_stride, self.kv_stride, + x_input.device, x_input.dtype, self.position_embedding_dim) + # (n, num_heads, w, w_kv, dim) + position_feat_x = self.appr_geom_fc_x(position_embed_x).\ + view(1, w, w_kv, num_heads, self.qk_embed_dim).\ + permute(0, 3, 1, 2, 4).\ + repeat(n, 1, 1, 1, 1) + + # (n, num_heads, h, h_kv, dim) + position_feat_y = self.appr_geom_fc_y(position_embed_y).\ + view(1, h, h_kv, num_heads, self.qk_embed_dim).\ + permute(0, 3, 1, 2, 4).\ + repeat(n, 1, 1, 1, 1) + + position_feat_x /= math.sqrt(2) + position_feat_y /= math.sqrt(2) + + # accelerate for saliency only + if (np.sum(self.attention_type) == 1) and self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim).\ + repeat(n, 1, 1, 1) + + energy = torch.matmul(appr_bias, proj_key).\ + view(n, num_heads, 1, h_kv * w_kv) + + h = 1 + w = 1 + else: + # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for + if not self.attention_type[0]: + energy = torch.zeros( + n, + num_heads, + h, + w, + h_kv, + w_kv, + dtype=x_input.dtype, + device=x_input.device) + + # attention_type[0]: appr - appr + # attention_type[1]: appr - position + # attention_type[2]: bias - appr + # attention_type[3]: bias - position + if self.attention_type[0] or self.attention_type[2]: + if self.attention_type[0] and self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim) + energy = torch.matmul(proj_query + appr_bias, proj_key).\ + view(n, num_heads, h, w, h_kv, w_kv) + + elif self.attention_type[0]: + energy = torch.matmul(proj_query, proj_key).\ + view(n, num_heads, h, w, h_kv, w_kv) + + elif self.attention_type[2]: + appr_bias = self.appr_bias.\ + view(1, num_heads, 1, self.qk_embed_dim).\ + repeat(n, 1, 1, 1) + + energy += torch.matmul(appr_bias, proj_key).\ + view(n, num_heads, 1, 1, h_kv, w_kv) + + if self.attention_type[1] or self.attention_type[3]: + if self.attention_type[1] and self.attention_type[3]: + geom_bias = self.geom_bias.\ + view(1, num_heads, 1, self.qk_embed_dim) + + proj_query_reshape = (proj_query + geom_bias).\ + view(n, num_heads, h, w, self.qk_embed_dim) + + energy_x = torch.matmul( + proj_query_reshape.permute(0, 1, 3, 2, 4), + position_feat_x.permute(0, 1, 2, 4, 3)) + energy_x = energy_x.\ + permute(0, 1, 3, 2, 4).unsqueeze(4) + + energy_y = torch.matmul( + proj_query_reshape, + position_feat_y.permute(0, 1, 2, 4, 3)) + energy_y = energy_y.unsqueeze(5) + + energy += energy_x + energy_y + + elif self.attention_type[1]: + proj_query_reshape = proj_query.\ + view(n, num_heads, h, w, self.qk_embed_dim) + proj_query_reshape = proj_query_reshape.\ + permute(0, 1, 3, 2, 4) + position_feat_x_reshape = position_feat_x.\ + permute(0, 1, 2, 4, 3) + position_feat_y_reshape = position_feat_y.\ + permute(0, 1, 2, 4, 3) + + energy_x = torch.matmul(proj_query_reshape, + position_feat_x_reshape) + energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4) + + energy_y = torch.matmul(proj_query_reshape, + position_feat_y_reshape) + energy_y = energy_y.unsqueeze(5) + + energy += energy_x + energy_y + + elif self.attention_type[3]: + geom_bias = self.geom_bias.\ + view(1, num_heads, self.qk_embed_dim, 1).\ + repeat(n, 1, 1, 1) + + position_feat_x_reshape = position_feat_x.\ + view(n, num_heads, w*w_kv, self.qk_embed_dim) + + position_feat_y_reshape = position_feat_y.\ + view(n, num_heads, h * h_kv, self.qk_embed_dim) + + energy_x = torch.matmul(position_feat_x_reshape, geom_bias) + energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv) + + energy_y = torch.matmul(position_feat_y_reshape, geom_bias) + energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1) + + energy += energy_x + energy_y + + energy = energy.view(n, num_heads, h * w, h_kv * w_kv) + + if self.spatial_range >= 0: + cur_local_constraint_map = \ + self.local_constraint_map[:h, :w, :h_kv, :w_kv].\ + contiguous().\ + view(1, 1, h*w, h_kv*w_kv) + + energy = energy.masked_fill_(cur_local_constraint_map, + float('-inf')) + + attention = F.softmax(energy, 3) + + proj_value = self.value_conv(x_kv) + proj_value_reshape = proj_value.\ + view((n, num_heads, self.v_dim, h_kv * w_kv)).\ + permute(0, 1, 3, 2) + + out = torch.matmul(attention, proj_value_reshape).\ + permute(0, 1, 3, 2).\ + contiguous().\ + view(n, self.v_dim * self.num_heads, h, w) + + out = self.proj_conv(out) + + # output is downsampled, upsample back to input size + if self.q_downsample is not None: + out = F.interpolate( + out, + size=x_input.shape[2:], + mode='bilinear', + align_corners=False) + + out = self.gamma * out + x_input + return out + + def init_weights(self): + for m in self.modules(): + if hasattr(m, 'kaiming_init') and m.kaiming_init: + kaiming_init( + m, + mode='fan_in', + nonlinearity='leaky_relu', + bias=0, + distribution='uniform', + a=1) diff --git a/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py new file mode 100644 index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class HSigmoid(nn.Module): + """Hard Sigmoid Module. Apply the hard sigmoid function: + Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value) + Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1) + + Args: + bias (float): Bias of the input feature map. Default: 1.0. + divisor (float): Divisor of the input feature map. Default: 2.0. + min_value (float): Lower bound value. Default: 0.0. + max_value (float): Upper bound value. Default: 1.0. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0): + super(HSigmoid, self).__init__() + self.bias = bias + self.divisor = divisor + assert self.divisor != 0 + self.min_value = min_value + self.max_value = max_value + + def forward(self, x): + x = (x + self.bias) / self.divisor + + return x.clamp_(self.min_value, self.max_value) diff --git a/annotator/uniformer/mmcv/cnn/bricks/hswish.py b/annotator/uniformer/mmcv/cnn/bricks/hswish.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/hswish.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class HSwish(nn.Module): + """Hard Swish Module. + + This module applies the hard swish function: + + .. math:: + Hswish(x) = x * ReLU6(x + 3) / 6 + + Args: + inplace (bool): can optionally do the operation in-place. + Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, inplace=False): + super(HSwish, self).__init__() + self.act = nn.ReLU6(inplace) + + def forward(self, x): + return x * self.act(x + 3) / 6 diff --git a/annotator/uniformer/mmcv/cnn/bricks/non_local.py b/annotator/uniformer/mmcv/cnn/bricks/non_local.py new file mode 100644 index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/non_local.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta + +import torch +import torch.nn as nn + +from ..utils import constant_init, normal_init +from .conv_module import ConvModule +from .registry import PLUGIN_LAYERS + + +class _NonLocalNd(nn.Module, metaclass=ABCMeta): + """Basic Non-local module. + + This module is proposed in + "Non-local Neural Networks" + Paper reference: https://arxiv.org/abs/1711.07971 + Code reference: https://github.com/AlexHex7/Non-local_pytorch + + Args: + in_channels (int): Channels of the input feature map. + reduction (int): Channel reduction ratio. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`. + Default: True. + conv_cfg (None | dict): The config dict for convolution layers. + If not specified, it will use `nn.Conv2d` for convolution layers. + Default: None. + norm_cfg (None | dict): The config dict for normalization layers. + Default: None. (This parameter is only applicable to conv_out.) + mode (str): Options are `gaussian`, `concatenation`, + `embedded_gaussian` and `dot_product`. Default: embedded_gaussian. + """ + + def __init__(self, + in_channels, + reduction=2, + use_scale=True, + conv_cfg=None, + norm_cfg=None, + mode='embedded_gaussian', + **kwargs): + super(_NonLocalNd, self).__init__() + self.in_channels = in_channels + self.reduction = reduction + self.use_scale = use_scale + self.inter_channels = max(in_channels // reduction, 1) + self.mode = mode + + if mode not in [ + 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation' + ]: + raise ValueError("Mode should be in 'gaussian', 'concatenation', " + f"'embedded_gaussian' or 'dot_product', but got " + f'{mode} instead.') + + # g, theta, phi are defaulted as `nn.ConvNd`. + # Here we use ConvModule for potential usage. + self.g = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + self.conv_out = ConvModule( + self.inter_channels, + self.in_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.mode != 'gaussian': + self.theta = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + self.phi = ConvModule( + self.in_channels, + self.inter_channels, + kernel_size=1, + conv_cfg=conv_cfg, + act_cfg=None) + + if self.mode == 'concatenation': + self.concat_project = ConvModule( + self.inter_channels * 2, + 1, + kernel_size=1, + stride=1, + padding=0, + bias=False, + act_cfg=dict(type='ReLU')) + + self.init_weights(**kwargs) + + def init_weights(self, std=0.01, zeros_init=True): + if self.mode != 'gaussian': + for m in [self.g, self.theta, self.phi]: + normal_init(m.conv, std=std) + else: + normal_init(self.g.conv, std=std) + if zeros_init: + if self.conv_out.norm_cfg is None: + constant_init(self.conv_out.conv, 0) + else: + constant_init(self.conv_out.norm, 0) + else: + if self.conv_out.norm_cfg is None: + normal_init(self.conv_out.conv, std=std) + else: + normal_init(self.conv_out.norm, std=std) + + def gaussian(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def embedded_gaussian(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= theta_x.shape[-1]**0.5 + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def dot_product(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + pairwise_weight /= pairwise_weight.shape[-1] + return pairwise_weight + + def concatenation(self, theta_x, phi_x): + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + h = theta_x.size(2) + w = phi_x.size(3) + theta_x = theta_x.repeat(1, 1, 1, w) + phi_x = phi_x.repeat(1, 1, h, 1) + + concat_feature = torch.cat([theta_x, phi_x], dim=1) + pairwise_weight = self.concat_project(concat_feature) + n, _, h, w = pairwise_weight.size() + pairwise_weight = pairwise_weight.view(n, h, w) + pairwise_weight /= pairwise_weight.shape[-1] + + return pairwise_weight + + def forward(self, x): + # Assume `reduction = 1`, then `inter_channels = C` + # or `inter_channels = C` when `mode="gaussian"` + + # NonLocal1d x: [N, C, H] + # NonLocal2d x: [N, C, H, W] + # NonLocal3d x: [N, C, T, H, W] + n = x.size(0) + + # NonLocal1d g_x: [N, H, C] + # NonLocal2d g_x: [N, HxW, C] + # NonLocal3d g_x: [N, TxHxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H] + # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW] + # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + pairwise_func = getattr(self, self.mode) + # NonLocal1d pairwise_weight: [N, H, H] + # NonLocal2d pairwise_weight: [N, HxW, HxW] + # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # NonLocal1d y: [N, H, C] + # NonLocal2d y: [N, HxW, C] + # NonLocal3d y: [N, TxHxW, C] + y = torch.matmul(pairwise_weight, g_x) + # NonLocal1d y: [N, C, H] + # NonLocal2d y: [N, C, H, W] + # NonLocal3d y: [N, C, T, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + output = x + self.conv_out(y) + + return output + + +class NonLocal1d(_NonLocalNd): + """1D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv1d'). + """ + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv1d'), + **kwargs): + super(NonLocal1d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool1d(kernel_size=2) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer + + +@PLUGIN_LAYERS.register_module() +class NonLocal2d(_NonLocalNd): + """2D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv2d'). + """ + + _abbr_ = 'nonlocal_block' + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv2d'), + **kwargs): + super(NonLocal2d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer + + +class NonLocal3d(_NonLocalNd): + """3D Non-local module. + + Args: + in_channels (int): Same as `NonLocalND`. + sub_sample (bool): Whether to apply max pooling after pairwise + function (Note that the `sub_sample` is applied on spatial only). + Default: False. + conv_cfg (None | dict): Same as `NonLocalND`. + Default: dict(type='Conv3d'). + """ + + def __init__(self, + in_channels, + sub_sample=False, + conv_cfg=dict(type='Conv3d'), + **kwargs): + super(NonLocal3d, self).__init__( + in_channels, conv_cfg=conv_cfg, **kwargs) + self.sub_sample = sub_sample + + if sub_sample: + max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) + self.g = nn.Sequential(self.g, max_pool_layer) + if self.mode != 'gaussian': + self.phi = nn.Sequential(self.phi, max_pool_layer) + else: + self.phi = max_pool_layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/norm.py b/annotator/uniformer/mmcv/cnn/bricks/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..408f4b42731b19a3beeef68b6a5e610d0bbc18b3 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/norm.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect + +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import is_tuple_of +from annotator.uniformer.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm +from .registry import NORM_LAYERS + +NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d) +NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d) +NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d) +NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm) +NORM_LAYERS.register_module('GN', module=nn.GroupNorm) +NORM_LAYERS.register_module('LN', module=nn.LayerNorm) +NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d) +NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d) +NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d) + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + When we build a norm layer with `build_norm_layer()`, we want to preserve + the norm type in variable names, e.g, self.bn1, self.gn. This method will + infer the abbreviation to map class types to abbreviations. + + Rule 1: If the class has the property "_abbr_", return the property. + Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or + InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and + "in" respectively. + Rule 3: If the class name contains "batch", "group", "layer" or "instance", + the abbreviation of this layer will be "bn", "gn", "ln" and "in" + respectively. + Rule 4: Otherwise, the abbreviation falls back to "norm". + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN + return 'in' + elif issubclass(class_type, _BatchNorm): + return 'bn' + elif issubclass(class_type, nn.GroupNorm): + return 'gn' + elif issubclass(class_type, nn.LayerNorm): + return 'ln' + else: + class_name = class_type.__name__.lower() + if 'batch' in class_name: + return 'bn' + elif 'group' in class_name: + return 'gn' + elif 'layer' in class_name: + return 'ln' + elif 'instance' in class_name: + return 'in' + else: + return 'norm_layer' + + +def build_norm_layer(cfg, num_features, postfix=''): + """Build normalization layer. + + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + - requires_grad (bool, optional): Whether stop gradient updates. + num_features (int): Number of input channels. + postfix (int | str): The postfix to be appended into norm abbreviation + to create named layer. + + Returns: + (str, nn.Module): The first element is the layer name consisting of + abbreviation and postfix, e.g., bn1, gn. The second element is the + created norm layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in NORM_LAYERS: + raise KeyError(f'Unrecognized norm type {layer_type}') + + norm_layer = NORM_LAYERS.get(layer_type) + abbr = infer_abbr(norm_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + else: + assert 'num_groups' in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer + + +def is_norm(layer, exclude=None): + """Check if a layer is a normalization layer. + + Args: + layer (nn.Module): The layer to be checked. + exclude (type | tuple[type]): Types to be excluded. + + Returns: + bool: Whether the layer is a norm layer. + """ + if exclude is not None: + if not isinstance(exclude, tuple): + exclude = (exclude, ) + if not is_tuple_of(exclude, type): + raise TypeError( + f'"exclude" must be either None or type or a tuple of types, ' + f'but got {type(exclude)}: {exclude}') + + if exclude and isinstance(layer, exclude): + return False + + all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) + return isinstance(layer, all_norm_bases) diff --git a/annotator/uniformer/mmcv/cnn/bricks/padding.py b/annotator/uniformer/mmcv/cnn/bricks/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/padding.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from .registry import PADDING_LAYERS + +PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d) +PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d) +PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d) + + +def build_padding_layer(cfg, *args, **kwargs): + """Build padding layer. + + Args: + cfg (None or dict): The padding layer config, which should contain: + - type (str): Layer type. + - layer args: Args needed to instantiate a padding layer. + + Returns: + nn.Module: Created padding layer. + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + + cfg_ = cfg.copy() + padding_type = cfg_.pop('type') + if padding_type not in PADDING_LAYERS: + raise KeyError(f'Unrecognized padding type {padding_type}.') + else: + padding_layer = PADDING_LAYERS.get(padding_type) + + layer = padding_layer(*args, **kwargs, **cfg_) + + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/plugin.py b/annotator/uniformer/mmcv/cnn/bricks/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/plugin.py @@ -0,0 +1,88 @@ +import inspect +import platform + +from .registry import PLUGIN_LAYERS + +if platform.system() == 'Windows': + import regex as re +else: + import re + + +def infer_abbr(class_type): + """Infer abbreviation from the class name. + + This method will infer the abbreviation to map class types to + abbreviations. + + Rule 1: If the class has the property "abbr", return the property. + Rule 2: Otherwise, the abbreviation falls back to snake case of class + name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``. + + Args: + class_type (type): The norm layer type. + + Returns: + str: The inferred abbreviation. + """ + + def camel2snack(word): + """Convert camel case word into snack case. + + Modified from `inflection lib + `_. + + Example:: + + >>> camel2snack("FancyBlock") + 'fancy_block' + """ + + word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word) + word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word) + word = word.replace('-', '_') + return word.lower() + + if not inspect.isclass(class_type): + raise TypeError( + f'class_type must be a type, but got {type(class_type)}') + if hasattr(class_type, '_abbr_'): + return class_type._abbr_ + else: + return camel2snack(class_type.__name__) + + +def build_plugin_layer(cfg, postfix='', **kwargs): + """Build plugin layer. + + Args: + cfg (None or dict): cfg should contain: + type (str): identify plugin layer type. + layer args: args needed to instantiate a plugin layer. + postfix (int, str): appended into norm abbreviation to + create named layer. Default: ''. + + Returns: + tuple[str, nn.Module]: + name (str): abbreviation + postfix + layer (nn.Module): created plugin layer + """ + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in PLUGIN_LAYERS: + raise KeyError(f'Unrecognized plugin type {layer_type}') + + plugin_layer = PLUGIN_LAYERS.get(layer_type) + abbr = infer_abbr(plugin_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + layer = plugin_layer(**kwargs, **cfg_) + + return name, layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/registry.py b/annotator/uniformer/mmcv/cnn/bricks/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..39eabc58db4b5954478a2ac1ab91cea5e45ab055 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/registry.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from annotator.uniformer.mmcv.utils import Registry + +CONV_LAYERS = Registry('conv layer') +NORM_LAYERS = Registry('norm layer') +ACTIVATION_LAYERS = Registry('activation layer') +PADDING_LAYERS = Registry('padding layer') +UPSAMPLE_LAYERS = Registry('upsample layer') +PLUGIN_LAYERS = Registry('plugin layer') + +DROPOUT_LAYERS = Registry('drop out layers') +POSITIONAL_ENCODING = Registry('position encoding') +ATTENTION = Registry('attention') +FEEDFORWARD_NETWORK = Registry('feed-forward Network') +TRANSFORMER_LAYER = Registry('transformerLayer') +TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence') diff --git a/annotator/uniformer/mmcv/cnn/bricks/scale.py b/annotator/uniformer/mmcv/cnn/bricks/scale.py new file mode 100644 index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/scale.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +class Scale(nn.Module): + """A learnable scale parameter. + + This layer scales the input by a learnable factor. It multiplies a + learnable scale parameter of shape (1,) with input of any shape. + + Args: + scale (float): Initial value of scale factor. Default: 1.0 + """ + + def __init__(self, scale=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) + + def forward(self, x): + return x * self.scale diff --git a/annotator/uniformer/mmcv/cnn/bricks/swish.py b/annotator/uniformer/mmcv/cnn/bricks/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/swish.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from .registry import ACTIVATION_LAYERS + + +@ACTIVATION_LAYERS.register_module() +class Swish(nn.Module): + """Swish Module. + + This module applies the swish function: + + .. math:: + Swish(x) = x * Sigmoid(x) + + Returns: + Tensor: The output tensor. + """ + + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) diff --git a/annotator/uniformer/mmcv/cnn/bricks/transformer.py b/annotator/uniformer/mmcv/cnn/bricks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e61ae0dd941a7be00b3e41a3de833ec50470a45f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/transformer.py @@ -0,0 +1,595 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings + +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning +from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer +from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential +from annotator.uniformer.mmcv.utils import build_from_cfg +from .drop import build_dropout +from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING, + TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) + +# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file +try: + from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401 + warnings.warn( + ImportWarning( + '``MultiScaleDeformableAttention`` has been moved to ' + '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501 + '``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501 + 'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501 + )) + +except ImportError: + warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' + '``mmcv.ops.multi_scale_deform_attn``, ' + 'You should install ``mmcv-full`` if you need this module. ') + + +def build_positional_encoding(cfg, default_args=None): + """Builder for Position Encoding.""" + return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args) + + +def build_attention(cfg, default_args=None): + """Builder for attention.""" + return build_from_cfg(cfg, ATTENTION, default_args) + + +def build_feedforward_network(cfg, default_args=None): + """Builder for feed-forward network (FFN).""" + return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args) + + +def build_transformer_layer(cfg, default_args=None): + """Builder for transformer layer.""" + return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args) + + +def build_transformer_layer_sequence(cfg, default_args=None): + """Builder for transformer encoder and transformer decoder.""" + return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args) + + +@ATTENTION.register_module() +class MultiheadAttention(BaseModule): + """A wrapper for ``torch.nn.MultiheadAttention``. + + This module implements MultiheadAttention with identity connection, + and positional encoding is also passed as input. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): When it is True, Key, Query and Value are shape of + (batch, n, embed_dim), otherwise (n, batch, embed_dim). + Default to False. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + init_cfg=None, + batch_first=False, + **kwargs): + super(MultiheadAttention, self).__init__(init_cfg) + if 'dropout' in kwargs: + warnings.warn('The arguments `dropout` in MultiheadAttention ' + 'has been deprecated, now you can separately ' + 'set `attn_drop`(float), proj_drop(float), ' + 'and `dropout_layer`(dict) ') + attn_drop = kwargs['dropout'] + dropout_layer['drop_prob'] = kwargs.pop('dropout') + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.batch_first = batch_first + + self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, + **kwargs) + + self.proj_drop = nn.Dropout(proj_drop) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiheadAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `MultiheadAttention`. + + **kwargs allow passing a more general data flow when combining + with other operations in `transformerlayer`. + + Args: + query (Tensor): The input query with shape [num_queries, bs, + embed_dims] if self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + If None, the ``query`` will be used. Defaults to None. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Defaults to None. + If None, the `key` will be used. + identity (Tensor): This tensor, with the same shape as x, + will be used for the identity link. + If None, `x` will be used. Defaults to None. + query_pos (Tensor): The positional encoding for query, with + the same shape as `x`. If not None, it will + be added to `x` before forward function. Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + + Returns: + Tensor: forwarded results with shape + [num_queries, bs, embed_dims] + if self.batch_first is False, else + [bs, num_queries embed_dims]. + """ + + if key is None: + key = query + if value is None: + value = key + if identity is None: + identity = query + if key_pos is None: + if query_pos is not None: + # use query_pos if key_pos is not available + if query_pos.shape == key.shape: + key_pos = query_pos + else: + warnings.warn(f'position encoding of key is' + f'missing in {self.__class__.__name__}.') + if query_pos is not None: + query = query + query_pos + if key_pos is not None: + key = key + key_pos + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + out = self.attn( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + +@FEEDFORWARD_NETWORK.register_module() +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning( + { + 'dropout': 'ffn_drop', + 'add_residual': 'add_identity' + }, + cls_name='FFN') + def __init__(self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0., + dropout_layer=None, + add_identity=True, + init_cfg=None, + **kwargs): + super(FFN, self).__init__(init_cfg) + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class BaseTransformerLayer(BaseModule): + """Base `TransformerLayer` for vision transformer. + + It can be built from `mmcv.ConfigDict` and support more flexible + customization, for example, using any number of `FFN or LN ` and + use different kinds of `attention` by specifying a list of `ConfigDict` + named `attn_cfgs`. It is worth mentioning that it supports `prenorm` + when you specifying `norm` as the first element of `operation_order`. + More details about the `prenorm`: `On Layer Normalization in the + Transformer Architecture `_ . + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for `self_attention` or `cross_attention` modules, + The order of the configs in the list should be consistent with + corresponding attentions in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. Default: None. + ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for FFN, The order of the configs in the list should be + consistent with corresponding ffn in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Support `prenorm` when you specifying first element as `norm`. + Default:None. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape + of (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + """ + + def __init__(self, + attn_cfgs=None, + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + operation_order=None, + norm_cfg=dict(type='LN'), + init_cfg=None, + batch_first=False, + **kwargs): + + deprecated_args = dict( + feedforward_channels='feedforward_channels', + ffn_dropout='ffn_drop', + ffn_num_fcs='num_fcs') + for ori_name, new_name in deprecated_args.items(): + if ori_name in kwargs: + warnings.warn( + f'The arguments `{ori_name}` in BaseTransformerLayer ' + f'has been deprecated, now you should set `{new_name}` ' + f'and other FFN related arguments ' + f'to a dict named `ffn_cfgs`. ') + ffn_cfgs[new_name] = kwargs[ori_name] + + super(BaseTransformerLayer, self).__init__(init_cfg) + + self.batch_first = batch_first + + assert set(operation_order) & set( + ['self_attn', 'norm', 'ffn', 'cross_attn']) == \ + set(operation_order), f'The operation_order of' \ + f' {self.__class__.__name__} should ' \ + f'contains all four operation type ' \ + f"{['self_attn', 'norm', 'ffn', 'cross_attn']}" + + num_attn = operation_order.count('self_attn') + operation_order.count( + 'cross_attn') + if isinstance(attn_cfgs, dict): + attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] + else: + assert num_attn == len(attn_cfgs), f'The length ' \ + f'of attn_cfg {num_attn} is ' \ + f'not consistent with the number of attention' \ + f'in operation_order {operation_order}.' + + self.num_attn = num_attn + self.operation_order = operation_order + self.norm_cfg = norm_cfg + self.pre_norm = operation_order[0] == 'norm' + self.attentions = ModuleList() + + index = 0 + for operation_name in operation_order: + if operation_name in ['self_attn', 'cross_attn']: + if 'batch_first' in attn_cfgs[index]: + assert self.batch_first == attn_cfgs[index]['batch_first'] + else: + attn_cfgs[index]['batch_first'] = self.batch_first + attention = build_attention(attn_cfgs[index]) + # Some custom attentions used as `self_attn` + # or `cross_attn` can have different behavior. + attention.operation_name = operation_name + self.attentions.append(attention) + index += 1 + + self.embed_dims = self.attentions[0].embed_dims + + self.ffns = ModuleList() + num_ffns = operation_order.count('ffn') + if isinstance(ffn_cfgs, dict): + ffn_cfgs = ConfigDict(ffn_cfgs) + if isinstance(ffn_cfgs, dict): + ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] + assert len(ffn_cfgs) == num_ffns + for ffn_index in range(num_ffns): + if 'embed_dims' not in ffn_cfgs[ffn_index]: + ffn_cfgs['embed_dims'] = self.embed_dims + else: + assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims + self.ffns.append( + build_feedforward_network(ffn_cfgs[ffn_index], + dict(type='FFN'))) + + self.norms = ModuleList() + num_norms = operation_order.count('norm') + for _ in range(num_norms): + self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) + + def forward(self, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `TransformerDecoderLayer`. + + **kwargs contains some specific arguments of attentions. + + Args: + query (Tensor): The input query with shape + [num_queries, bs, embed_dims] if + self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + value (Tensor): The value tensor with same shape as `key`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor] | None): 2D Tensor used in + calculation of corresponding attention. The length of + it should equal to the number of `attention` in + `operation_order`. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in `self_attn` layer. + Defaults to None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: forwarded results with shape [num_queries, bs, embed_dims]. + """ + + norm_index = 0 + attn_index = 0 + ffn_index = 0 + identity = query + if attn_masks is None: + attn_masks = [None for _ in range(self.num_attn)] + elif isinstance(attn_masks, torch.Tensor): + attn_masks = [ + copy.deepcopy(attn_masks) for _ in range(self.num_attn) + ] + warnings.warn(f'Use same attn_mask in all attentions in ' + f'{self.__class__.__name__} ') + else: + assert len(attn_masks) == self.num_attn, f'The length of ' \ + f'attn_masks {len(attn_masks)} must be equal ' \ + f'to the number of attention in ' \ + f'operation_order {self.num_attn}' + + for layer in self.operation_order: + if layer == 'self_attn': + temp_key = temp_value = query + query = self.attentions[attn_index]( + query, + temp_key, + temp_value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=query_key_padding_mask, + **kwargs) + attn_index += 1 + identity = query + + elif layer == 'norm': + query = self.norms[norm_index](query) + norm_index += 1 + + elif layer == 'cross_attn': + query = self.attentions[attn_index]( + query, + key, + value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=key_padding_mask, + **kwargs) + attn_index += 1 + identity = query + + elif layer == 'ffn': + query = self.ffns[ffn_index]( + query, identity if self.pre_norm else None) + ffn_index += 1 + + return query + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class TransformerLayerSequence(BaseModule): + """Base class for TransformerEncoder and TransformerDecoder in vision + transformer. + + As base-class of Encoder and Decoder in vision transformer. + Support customization such as specifying different kind + of `transformer_layer` in `transformer_coder`. + + Args: + transformerlayer (list[obj:`mmcv.ConfigDict`] | + obj:`mmcv.ConfigDict`): Config of transformerlayer + in TransformerCoder. If it is obj:`mmcv.ConfigDict`, + it would be repeated `num_layer` times to a + list[`mmcv.ConfigDict`]. Default: None. + num_layers (int): The number of `TransformerLayer`. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None): + super(TransformerLayerSequence, self).__init__(init_cfg) + if isinstance(transformerlayers, dict): + transformerlayers = [ + copy.deepcopy(transformerlayers) for _ in range(num_layers) + ] + else: + assert isinstance(transformerlayers, list) and \ + len(transformerlayers) == num_layers + self.num_layers = num_layers + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append(build_transformer_layer(transformerlayers[i])) + self.embed_dims = self.layers[0].embed_dims + self.pre_norm = self.layers[0].pre_norm + + def forward(self, + query, + key, + value, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `TransformerCoder`. + + Args: + query (Tensor): Input query with shape + `(num_queries, bs, embed_dims)`. + key (Tensor): The key tensor with shape + `(num_keys, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_keys, bs, embed_dims)`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor], optional): Each element is 2D Tensor + which is used in calculation of corresponding attention in + operation_order. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in self-attention + Default: None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: results with shape [num_queries, bs, embed_dims]. + """ + for layer in self.layers: + query = layer( + query, + key, + value, + query_pos=query_pos, + key_pos=key_pos, + attn_masks=attn_masks, + query_key_padding_mask=query_key_padding_mask, + key_padding_mask=key_padding_mask, + **kwargs) + return query diff --git a/annotator/uniformer/mmcv/cnn/bricks/upsample.py b/annotator/uniformer/mmcv/cnn/bricks/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/upsample.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import xavier_init +from .registry import UPSAMPLE_LAYERS + +UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample) +UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample) + + +@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle') +class PixelShufflePack(nn.Module): + """Pixel Shuffle upsample layer. + + This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to + achieve a simple upsampling with pixel shuffle. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scale_factor (int): Upsample ratio. + upsample_kernel (int): Kernel size of the conv layer to expand the + channels. + """ + + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super(PixelShufflePack, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2d( + self.in_channels, + self.out_channels * scale_factor * scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.init_weights() + + def init_weights(self): + xavier_init(self.upsample_conv, distribution='uniform') + + def forward(self, x): + x = self.upsample_conv(x) + x = F.pixel_shuffle(x, self.scale_factor) + return x + + +def build_upsample_layer(cfg, *args, **kwargs): + """Build upsample layer. + + Args: + cfg (dict): The upsample layer config, which should contain: + + - type (str): Layer type. + - scale_factor (int): Upsample ratio, which is not applicable to + deconv. + - layer args: Args needed to instantiate a upsample layer. + args (argument list): Arguments passed to the ``__init__`` + method of the corresponding conv layer. + kwargs (keyword arguments): Keyword arguments passed to the + ``__init__`` method of the corresponding conv layer. + + Returns: + nn.Module: Created upsample layer. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in UPSAMPLE_LAYERS: + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = UPSAMPLE_LAYERS.get(layer_type) + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer diff --git a/annotator/uniformer/mmcv/cnn/bricks/wrappers.py b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501 + +Wrap some nn modules to support empty tensor input. Currently, these wrappers +are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask +heads are trained on only positive RoIs. +""" +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair, _triple + +from .registry import CONV_LAYERS, UPSAMPLE_LAYERS + +if torch.__version__ == 'parrots': + TORCH_VERSION = torch.__version__ +else: + # torch.__version__ could be 1.3.1+cu92, we only need the first two + # for comparison + TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + + +def obsolete_torch_version(torch_version, version_threshold): + return torch_version == 'parrots' or torch_version <= version_threshold + + +class NewEmptyTensorOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return NewEmptyTensorOp.apply(grad, shape), None + + +@CONV_LAYERS.register_module('Conv', force=True) +class Conv2d(nn.Conv2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module('Conv3d', force=True) +class Conv3d(nn.Conv3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, self.dilation): + o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv') +@UPSAMPLE_LAYERS.register_module('deconv', force=True) +class ConvTranspose2d(nn.ConvTranspose2d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +@CONV_LAYERS.register_module() +@CONV_LAYERS.register_module('deconv3d') +@UPSAMPLE_LAYERS.register_module('deconv3d', force=True) +class ConvTranspose3d(nn.ConvTranspose3d): + + def forward(self, x): + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): + out_shape = [x.shape[0], self.out_channels] + for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, + self.padding, self.stride, + self.dilation, self.output_padding): + out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) + + +class MaxPool2d(nn.MaxPool2d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), + _pair(self.padding), _pair(self.stride), + _pair(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class MaxPool3d(nn.MaxPool3d): + + def forward(self, x): + # PyTorch 1.9 does not support empty tensor inference yet + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): + out_shape = list(x.shape[:2]) + for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size), + _triple(self.padding), + _triple(self.stride), + _triple(self.dilation)): + o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 + o = math.ceil(o) if self.ceil_mode else math.floor(o) + out_shape.append(o) + empty = NewEmptyTensorOp.apply(x, out_shape) + return empty + + return super().forward(x) + + +class Linear(torch.nn.Linear): + + def forward(self, x): + # empty tensor forward of Linear layer is supported in Pytorch 1.6 + if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): + out_shape = [x.shape[0], self.out_features] + empty = NewEmptyTensorOp.apply(x, out_shape) + if self.training: + # produce dummy gradient to avoid DDP warning. + dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + dummy + else: + return empty + + return super().forward(x) diff --git a/annotator/uniformer/mmcv/cnn/builder.py b/annotator/uniformer/mmcv/cnn/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/builder.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..runner import Sequential +from ..utils import Registry, build_from_cfg + + +def build_model_from_cfg(cfg, registry, default_args=None): + """Build a PyTorch model from config dict(s). Different from + ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a config + dict or a list of config dicts. If cfg is a list, a + the built modules will be wrapped with ``nn.Sequential``. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. + """ + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +MODELS = Registry('model', build_func=build_model_from_cfg) diff --git a/annotator/uniformer/mmcv/cnn/resnet.py b/annotator/uniformer/mmcv/cnn/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/resnet.py @@ -0,0 +1,316 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn +import torch.utils.checkpoint as cp + +from .utils import constant_init, kaiming_init + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False): + super(BasicBlock, self).__init__() + assert style in ['pytorch', 'caffe'] + self.conv1 = conv3x3(inplanes, planes, stride, dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + assert not with_cp + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False): + """Bottleneck block. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super(Bottleneck, self).__init__() + assert style in ['pytorch', 'caffe'] + if style == 'pytorch': + conv1_stride = 1 + conv2_stride = stride + else: + conv1_stride = stride + conv2_stride = 1 + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + def forward(self, x): + + def _inner_forward(x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +def make_res_layer(block, + inplanes, + planes, + blocks, + stride=1, + dilation=1, + style='pytorch', + with_cp=False): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + inplanes, + planes, + stride, + dilation, + downsample, + style=style, + with_cp=with_cp)) + inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp)) + + return nn.Sequential(*layers) + + +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + num_stages (int): Resnet stages, normally 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze + running stats (mean and var). + bn_frozen (bool): Whether to freeze weight and bias of BN layers. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + frozen_stages=-1, + bn_eval=True, + bn_frozen=False, + with_cp=False): + super(ResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + assert num_stages >= 1 and num_stages <= 4 + block, stage_blocks = self.arch_settings[depth] + stage_blocks = stage_blocks[:num_stages] + assert len(strides) == len(dilations) == num_stages + assert max(out_indices) < num_stages + + self.out_indices = out_indices + self.style = style + self.frozen_stages = frozen_stages + self.bn_eval = bn_eval + self.bn_frozen = bn_frozen + self.with_cp = with_cp + + self.inplanes = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.res_layers = [] + for i, num_blocks in enumerate(stage_blocks): + stride = strides[i] + dilation = dilations[i] + planes = 64 * 2**i + res_layer = make_res_layer( + block, + self.inplanes, + planes, + num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + with_cp=with_cp) + self.inplanes = planes * block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def train(self, mode=True): + super(ResNet, self).train(mode) + if self.bn_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if self.bn_frozen: + for params in m.parameters(): + params.requires_grad = False + if mode and self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for param in self.bn1.parameters(): + param.requires_grad = False + self.bn1.eval() + self.bn1.weight.requires_grad = False + self.bn1.bias.requires_grad = False + for i in range(1, self.frozen_stages + 1): + mod = getattr(self, f'layer{i}') + mod.eval() + for param in mod.parameters(): + param.requires_grad = False diff --git a/annotator/uniformer/mmcv/cnn/utils/__init__.py b/annotator/uniformer/mmcv/cnn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .flops_counter import get_model_complexity_info +from .fuse_conv_bn import fuse_conv_bn +from .sync_bn import revert_sync_batchnorm +from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit, + KaimingInit, NormalInit, PretrainedInit, + TruncNormalInit, UniformInit, XavierInit, + bias_init_with_prob, caffe2_xavier_init, + constant_init, initialize, kaiming_init, normal_init, + trunc_normal_init, uniform_init, xavier_init) + +__all__ = [ + 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init', + 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init', + 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize', + 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit', + 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', + 'Caffe2XavierInit', 'revert_sync_batchnorm' +] diff --git a/annotator/uniformer/mmcv/cnn/utils/flops_counter.py b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..d10af5feca7f4b8c0ba359b7b1c826f754e048be --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py @@ -0,0 +1,599 @@ +# Modified from flops-counter.pytorch by Vladislav Sovrasov +# original repo: https://github.com/sovrasov/flops-counter.pytorch + +# MIT License + +# Copyright (c) 2018 Vladislav Sovrasov + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +import annotator.uniformer.mmcv as mmcv + + +def get_model_complexity_info(model, + input_shape, + print_per_layer_stat=True, + as_strings=True, + input_constructor=None, + flush=False, + ost=sys.stdout): + """Get complexity information of a model. + + This method can calculate FLOPs and parameter counts of a model with + corresponding input shape. It can also print complexity information for + each layer in a model. + + Supported layers are listed as below: + - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. + - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, + ``nn.ReLU6``. + - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``, + ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``, + ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``, + ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``, + ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``. + - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, + ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``, + ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``. + - Linear: ``nn.Linear``. + - Deconvolution: ``nn.ConvTranspose2d``. + - Upsample: ``nn.Upsample``. + + Args: + model (nn.Module): The model for complexity calculation. + input_shape (tuple): Input shape used for calculation. + print_per_layer_stat (bool): Whether to print complexity information + for each layer in a model. Default: True. + as_strings (bool): Output FLOPs and params counts in a string form. + Default: True. + input_constructor (None | callable): If specified, it takes a callable + method that generates input. otherwise, it will generate a random + tensor with input shape to calculate FLOPs. Default: None. + flush (bool): same as that in :func:`print`. Default: False. + ost (stream): same as ``file`` param in :func:`print`. + Default: sys.stdout. + + Returns: + tuple[float | str]: If ``as_strings`` is set to True, it will return + FLOPs and parameter counts in a string format. otherwise, it will + return those in a float number format. + """ + assert type(input_shape) is tuple + assert len(input_shape) >= 1 + assert isinstance(model, nn.Module) + flops_model = add_flops_counting_methods(model) + flops_model.eval() + flops_model.start_flops_count() + if input_constructor: + input = input_constructor(input_shape) + _ = flops_model(**input) + else: + try: + batch = torch.ones(()).new_empty( + (1, *input_shape), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + except StopIteration: + # Avoid StopIteration for models which have no parameters, + # like `nn.Relu()`, `nn.AvgPool2d`, etc. + batch = torch.ones(()).new_empty((1, *input_shape)) + + _ = flops_model(batch) + + flops_count, params_count = flops_model.compute_average_flops_cost() + if print_per_layer_stat: + print_model_with_flops( + flops_model, flops_count, params_count, ost=ost, flush=flush) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GFLOPs', precision=2): + """Convert FLOPs number into a string. + + Note that Here we take a multiply-add counts as one FLOP. + + Args: + flops (float): FLOPs number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'GFLOPs', + 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically + choose the most suitable unit for FLOPs. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted FLOPs number with units. + + Examples: + >>> flops_to_string(1e9) + '1.0 GFLOPs' + >>> flops_to_string(2e5, 'MFLOPs') + '0.2 MFLOPs' + >>> flops_to_string(3e-9, None) + '3e-09 FLOPs' + """ + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GFLOPs' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MFLOPs' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KFLOPs' + else: + return str(flops) + ' FLOPs' + else: + if units == 'GFLOPs': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MFLOPs': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KFLOPs': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' FLOPs' + + +def params_to_string(num_params, units=None, precision=2): + """Convert parameter number into a string. + + Args: + num_params (float): Parameter number to be converted. + units (str | None): Converted FLOPs units. Options are None, 'M', + 'K' and ''. If set to None, it will automatically choose the most + suitable unit for Parameter number. Default: None. + precision (int): Digit number after the decimal point. Default: 2. + + Returns: + str: The converted parameter number with units. + + Examples: + >>> params_to_string(1e9) + '1000.0 M' + >>> params_to_string(2e5) + '200.0 k' + >>> params_to_string(3e-9) + '3e-09' + """ + if units is None: + if num_params // 10**6 > 0: + return str(round(num_params / 10**6, precision)) + ' M' + elif num_params // 10**3: + return str(round(num_params / 10**3, precision)) + ' k' + else: + return str(num_params) + else: + if units == 'M': + return str(round(num_params / 10.**6, precision)) + ' ' + units + elif units == 'K': + return str(round(num_params / 10.**3, precision)) + ' ' + units + else: + return str(num_params) + + +def print_model_with_flops(model, + total_flops, + total_params, + units='GFLOPs', + precision=3, + ost=sys.stdout, + flush=False): + """Print a model with FLOPs for each layer. + + Args: + model (nn.Module): The model to be printed. + total_flops (float): Total FLOPs of the model. + total_params (float): Total parameter counts of the model. + units (str | None): Converted FLOPs units. Default: 'GFLOPs'. + precision (int): Digit number after the decimal point. Default: 3. + ost (stream): same as `file` param in :func:`print`. + Default: sys.stdout. + flush (bool): same as that in :func:`print`. Default: False. + + Example: + >>> class ExampleModel(nn.Module): + + >>> def __init__(self): + >>> super().__init__() + >>> self.conv1 = nn.Conv2d(3, 8, 3) + >>> self.conv2 = nn.Conv2d(8, 256, 3) + >>> self.conv3 = nn.Conv2d(256, 8, 3) + >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Linear(8, 1) + + >>> def forward(self, x): + >>> x = self.conv1(x) + >>> x = self.conv2(x) + >>> x = self.conv3(x) + >>> x = self.avg_pool(x) + >>> x = self.flatten(x) + >>> x = self.fc(x) + >>> return x + + >>> model = ExampleModel() + >>> x = (3, 16, 16) + to print the complexity information state for each layer, you can use + >>> get_model_complexity_info(model, x) + or directly use + >>> print_model_with_flops(model, 4579784.0, 37361) + ExampleModel( + 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, + (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501 + (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1)) + (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1)) + (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1)) + (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, ) + (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True) + ) + """ + + def accumulate_params(self): + if is_supported_instance(self): + return self.__params__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_params() + return sum + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_num_params = self.accumulate_params() + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([ + params_to_string( + accumulated_num_params, units='M', precision=precision), + '{:.3%} Params'.format(accumulated_num_params / total_params), + flops_to_string( + accumulated_flops_cost, units=units, precision=precision), + '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr() + ]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + m.accumulate_params = accumulate_params.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model, file=ost, flush=flush) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + """Calculate parameter number of a model. + + Args: + model (nn.module): The model for parameter number calculation. + + Returns: + float: Parameter number of the model. + """ + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num_params + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__( + net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__( + net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__( + net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501 + net_main_module) + + net_main_module.reset_flops_count() + + return net_main_module + + +def compute_average_flops_cost(self): + """Compute average FLOPs cost. + + A method to compute average FLOPs cost, which will be available after + `add_flops_counting_methods()` is called on a desired net object. + + Returns: + float: Current mean flops consumption per image. + """ + batches_count = self.__batch_counter__ + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + params_sum = get_model_parameters_number(self) + return flops_sum / batches_count, params_sum + + +def start_flops_count(self): + """Activate the computation of mean flops consumption per image. + + A method to activate the computation of mean flops consumption per image. + which will be available after ``add_flops_counting_methods()`` is called on + a desired net object. It should be called before running the network. + """ + add_batch_counter_hook_function(self) + + def add_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + else: + handle = module.register_forward_hook( + get_modules_mapping()[type(module)]) + + module.__flops_handle__ = handle + + self.apply(partial(add_flops_counter_hook_function)) + + +def stop_flops_count(self): + """Stop computing the mean flops consumption per image. + + A method to stop computing the mean flops consumption per image, which will + be available after ``add_flops_counting_methods()`` is called on a desired + net object. It can be called to pause the computation whenever. + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """Reset statistics computed so far. + + A method to Reset computed statistics, which will be available after + `add_flops_counting_methods()` is called on a desired net object. + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +# ---- Internal functions +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + output_last_dim = output.shape[ + -1] # pytorch checks dimensions, so here we don't care much + module.__flops__ += int(np.prod(input.shape) * output_last_dim) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def norm_flops_counter_hook(module, input, output): + input = input[0] + + batch_flops = np.prod(input.shape) + if (getattr(module, 'affine', False) + or getattr(module, 'elementwise_affine', False)): + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +def deconv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + input_height, input_width = input.shape[2:] + + kernel_height, kernel_width = conv_module.kernel_size + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = ( + kernel_height * kernel_width * in_channels * filters_per_channel) + + active_elements_count = batch_size * input_height * input_width + overall_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if conv_module.bias is not None: + output_height, output_width = output.shape[2:] + bias_flops = out_channels * batch_size * output_height * output_height + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = int( + np.prod(kernel_dims)) * in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def batch_counter_hook(module, input, output): + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = len(input) + else: + pass + print('Warning! No positional inputs found for a module, ' + 'assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + if hasattr(module, '__flops__') or hasattr(module, '__params__'): + print('Warning: variables __flops__ or __params__ are already ' + 'defined for the module' + type(module).__name__ + + ' ptflops can affect your code!') + module.__flops__ = 0 + module.__params__ = get_model_parameters_number(module) + + +def is_supported_instance(module): + if type(module) in get_modules_mapping(): + return True + return False + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +def get_modules_mapping(): + return { + # convolutions + nn.Conv1d: conv_flops_counter_hook, + nn.Conv2d: conv_flops_counter_hook, + mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook, + nn.Conv3d: conv_flops_counter_hook, + mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook, + # activations + nn.ReLU: relu_flops_counter_hook, + nn.PReLU: relu_flops_counter_hook, + nn.ELU: relu_flops_counter_hook, + nn.LeakyReLU: relu_flops_counter_hook, + nn.ReLU6: relu_flops_counter_hook, + # poolings + nn.MaxPool1d: pool_flops_counter_hook, + nn.AvgPool1d: pool_flops_counter_hook, + nn.AvgPool2d: pool_flops_counter_hook, + nn.MaxPool2d: pool_flops_counter_hook, + mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook, + nn.MaxPool3d: pool_flops_counter_hook, + mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook, + nn.AvgPool3d: pool_flops_counter_hook, + nn.AdaptiveMaxPool1d: pool_flops_counter_hook, + nn.AdaptiveAvgPool1d: pool_flops_counter_hook, + nn.AdaptiveMaxPool2d: pool_flops_counter_hook, + nn.AdaptiveAvgPool2d: pool_flops_counter_hook, + nn.AdaptiveMaxPool3d: pool_flops_counter_hook, + nn.AdaptiveAvgPool3d: pool_flops_counter_hook, + # normalizations + nn.BatchNorm1d: norm_flops_counter_hook, + nn.BatchNorm2d: norm_flops_counter_hook, + nn.BatchNorm3d: norm_flops_counter_hook, + nn.GroupNorm: norm_flops_counter_hook, + nn.InstanceNorm1d: norm_flops_counter_hook, + nn.InstanceNorm2d: norm_flops_counter_hook, + nn.InstanceNorm3d: norm_flops_counter_hook, + nn.LayerNorm: norm_flops_counter_hook, + # FC + nn.Linear: linear_flops_counter_hook, + mmcv.cnn.bricks.Linear: linear_flops_counter_hook, + # Upscale + nn.Upsample: upsample_flops_counter_hook, + # Deconvolution + nn.ConvTranspose2d: deconv_flops_counter_hook, + mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook, + } diff --git a/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def _fuse_conv_bn(conv, bn): + """Fuse conv and bn into one module. + + Args: + conv (nn.Module): Conv to be fused. + bn (nn.Module): BN to be fused. + + Returns: + nn.Module: Fused module. + """ + conv_w = conv.weight + conv_b = conv.bias if conv.bias is not None else torch.zeros_like( + bn.running_mean) + + factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) + conv.weight = nn.Parameter(conv_w * + factor.reshape([conv.out_channels, 1, 1, 1])) + conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) + return conv + + +def fuse_conv_bn(module): + """Recursively fuse conv and bn in a module. + + During inference, the functionary of batch norm layers is turned off + but only the mean and var alone channels are used, which exposes the + chance to fuse it with the preceding conv layers to save computations and + simplify network structures. + + Args: + module (nn.Module): Module to be fused. + + Returns: + nn.Module: Fused module. + """ + last_conv = None + last_conv_name = None + + for name, child in module.named_children(): + if isinstance(child, + (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)): + if last_conv is None: # only fuse BN that is after Conv + continue + fused_conv = _fuse_conv_bn(last_conv, child) + module._modules[last_conv_name] = fused_conv + # To reduce changes, set BN as Identity instead of deleting it. + module._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + fuse_conv_bn(child) + return module diff --git a/annotator/uniformer/mmcv/cnn/utils/sync_bn.py b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..f78f39181d75bb85c53e8c7c8eaf45690e9f0bee --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py @@ -0,0 +1,59 @@ +import torch + +import annotator.uniformer.mmcv as mmcv + + +class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): + """A general BatchNorm layer without input dimension check. + + Reproduced from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc + is `_check_input_dim` that is designed for tensor sanity checks. + The check has been bypassed in this class for the convenience of converting + SyncBatchNorm. + """ + + def _check_input_dim(self, input): + return + + +def revert_sync_batchnorm(module): + """Helper function to convert all `SyncBatchNorm` (SyncBN) and + `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to + `BatchNormXd` layers. + + Adapted from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + Args: + module (nn.Module): The module containing `SyncBatchNorm` layers. + + Returns: + module_output: The converted module with `BatchNormXd` layers. + """ + module_output = module + module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] + if hasattr(mmcv, 'ops'): + module_checklist.append(mmcv.ops.SyncBatchNorm) + if isinstance(module, tuple(module_checklist)): + module_output = _BatchNormXd(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + # no_grad() may not be needed here but + # just to be consistent with `convert_sync_batchnorm()` + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + # qconfig exists in quantized models + if hasattr(module, 'qconfig'): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, revert_sync_batchnorm(child)) + del module + return module_output diff --git a/annotator/uniformer/mmcv/cnn/utils/weight_init.py b/annotator/uniformer/mmcv/cnn/utils/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..287a1d0bffe26e023029d48634d9b761deda7ba4 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/utils/weight_init.py @@ -0,0 +1,684 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +import warnings + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log + +INITIALIZERS = Registry('initializer') + + +def update_init_info(module, init_info): + """Update the `_params_init_info` in the module if the value of parameters + are changed. + + Args: + module (obj:`nn.Module`): The module of PyTorch with a user-defined + attribute `_params_init_info` which records the initialization + information. + init_info (str): The string that describes the initialization. + """ + assert hasattr( + module, + '_params_init_info'), f'Can not find `_params_init_info` in {module}' + for name, param in module.named_parameters(): + + assert param in module._params_init_info, ( + f'Find a new :obj:`Parameter` ' + f'named `{name}` during executing the ' + f'`init_weights` of ' + f'`{module.__class__.__name__}`. ' + f'Please do not add or ' + f'replace parameters during executing ' + f'the `init_weights`. ') + + # The parameter has been changed during executing the + # `init_weights` of module + mean_value = param.data.mean() + if module._params_init_info[param]['tmp_mean_value'] != mean_value: + module._params_init_info[param]['init_info'] = init_info + module._params_init_info[param]['tmp_mean_value'] = mean_value + + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def xavier_init(module, gain=1, bias=0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module, mean=0, std=1, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def trunc_normal_init(module: nn.Module, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + bias: float = 0) -> None: + if hasattr(module, 'weight') and module.weight is not None: + trunc_normal_(module.weight, mean, std, a, b) # type: ignore + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) # type: ignore + + +def uniform_init(module, a=0, b=1, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def kaiming_init(module, + a=0, + mode='fan_out', + nonlinearity='relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if hasattr(module, 'weight') and module.weight is not None: + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def caffe2_xavier_init(module, bias=0): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + kaiming_init( + module, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + bias=bias, + distribution='uniform') + + +def bias_init_with_prob(prior_prob): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def _get_bases_name(m): + return [b.__name__ for b in m.__class__.__bases__] + + +class BaseInit(object): + + def __init__(self, *, bias=0, bias_prob=None, layer=None): + self.wholemodule = False + if not isinstance(bias, (int, float)): + raise TypeError(f'bias must be a number, but got a {type(bias)}') + + if bias_prob is not None: + if not isinstance(bias_prob, float): + raise TypeError(f'bias_prob type must be float, \ + but got {type(bias_prob)}') + + if layer is not None: + if not isinstance(layer, (str, list)): + raise TypeError(f'layer must be a str or a list of str, \ + but got a {type(layer)}') + else: + layer = [] + + if bias_prob is not None: + self.bias = bias_init_with_prob(bias_prob) + else: + self.bias = bias + self.layer = [layer] if isinstance(layer, str) else layer + + def _get_init_info(self): + info = f'{self.__class__.__name__}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Constant') +class ConstantInit(BaseInit): + """Initialize module parameters with constant values. + + Args: + val (int | float): the value to fill the weights in the module with + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, val, **kwargs): + super().__init__(**kwargs) + self.val = val + + def __call__(self, module): + + def init(m): + if self.wholemodule: + constant_init(m, self.val, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + constant_init(m, self.val, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Xavier') +class XavierInit(BaseInit): + r"""Initialize module parameters with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks - Glorot, X. & Bengio, Y. (2010). + `_ + + Args: + gain (int | float): an optional scaling factor. Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` + or ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, gain=1, distribution='normal', **kwargs): + super().__init__(**kwargs) + self.gain = gain + self.distribution = distribution + + def __call__(self, module): + + def init(m): + if self.wholemodule: + xavier_init(m, self.gain, self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + xavier_init(m, self.gain, self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: gain={self.gain}, ' \ + f'distribution={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Normal') +class NormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + mean (int | float):the mean of the normal distribution. Defaults to 0. + std (int | float): the standard deviation of the normal distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, mean=0, std=1, **kwargs): + super().__init__(**kwargs) + self.mean = mean + self.std = std + + def __call__(self, module): + + def init(m): + if self.wholemodule: + normal_init(m, self.mean, self.std, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + normal_init(m, self.mean, self.std, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: mean={self.mean},' \ + f' std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='TruncNormal') +class TruncNormalInit(BaseInit): + r"""Initialize module parameters with the values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values + outside :math:`[a, b]`. + + Args: + mean (float): the mean of the normal distribution. Defaults to 0. + std (float): the standard deviation of the normal distribution. + Defaults to 1. + a (float): The minimum cutoff value. + b ( float): The maximum cutoff value. + bias (float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + + """ + + def __init__(self, + mean: float = 0, + std: float = 1, + a: float = -2, + b: float = 2, + **kwargs) -> None: + super().__init__(**kwargs) + self.mean = mean + self.std = std + self.a = a + self.b = b + + def __call__(self, module: nn.Module) -> None: + + def init(m): + if self.wholemodule: + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + trunc_normal_init(m, self.mean, self.std, self.a, self.b, + self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \ + f' mean={self.mean}, std={self.std}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Uniform') +class UniformInit(BaseInit): + r"""Initialize module parameters with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + + Args: + a (int | float): the lower bound of the uniform distribution. + Defaults to 0. + b (int | float): the upper bound of the uniform distribution. + Defaults to 1. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, a=0, b=1, **kwargs): + super().__init__(**kwargs) + self.a = a + self.b = b + + def __call__(self, module): + + def init(m): + if self.wholemodule: + uniform_init(m, self.a, self.b, self.bias) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + uniform_init(m, self.a, self.b, self.bias) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a},' \ + f' b={self.b}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Kaiming') +class KaimingInit(BaseInit): + r"""Initialize module parameters with the values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification - He, K. et al. (2015). + `_ + + Args: + a (int | float): the negative slope of the rectifier used after this + layer (only used with ``'leaky_relu'``). Defaults to 0. + mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing + ``'fan_in'`` preserves the magnitude of the variance of the weights + in the forward pass. Choosing ``'fan_out'`` preserves the + magnitudes in the backwards pass. Defaults to ``'fan_out'``. + nonlinearity (str): the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` . + Defaults to 'relu'. + bias (int | float): the value to fill the bias. Defaults to 0. + bias_prob (float, optional): the probability for bias initialization. + Defaults to None. + distribution (str): distribution either be ``'normal'`` or + ``'uniform'``. Defaults to ``'normal'``. + layer (str | list[str], optional): the layer will be initialized. + Defaults to None. + """ + + def __init__(self, + a=0, + mode='fan_out', + nonlinearity='relu', + distribution='normal', + **kwargs): + super().__init__(**kwargs) + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + self.distribution = distribution + + def __call__(self, module): + + def init(m): + if self.wholemodule: + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + else: + layername = m.__class__.__name__ + basesname = _get_bases_name(m) + if len(set(self.layer) & set([layername] + basesname)): + kaiming_init(m, self.a, self.mode, self.nonlinearity, + self.bias, self.distribution) + + module.apply(init) + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \ + f'nonlinearity={self.nonlinearity}, ' \ + f'distribution ={self.distribution}, bias={self.bias}' + return info + + +@INITIALIZERS.register_module(name='Caffe2Xavier') +class Caffe2XavierInit(KaimingInit): + # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch + # Acknowledgment to FAIR's internal code + def __init__(self, **kwargs): + super().__init__( + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) + + def __call__(self, module): + super().__call__(module) + + +@INITIALIZERS.register_module(name='Pretrained') +class PretrainedInit(object): + """Initialize module by loading a pretrained model. + + Args: + checkpoint (str): the checkpoint file of the pretrained model should + be load. + prefix (str, optional): the prefix of a sub-module in the pretrained + model. it is for loading a part of the pretrained model to + initialize. For example, if we would like to only load the + backbone of a detector model, we can set ``prefix='backbone.'``. + Defaults to None. + map_location (str): map tensors into proper locations. + """ + + def __init__(self, checkpoint, prefix=None, map_location=None): + self.checkpoint = checkpoint + self.prefix = prefix + self.map_location = map_location + + def __call__(self, module): + from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint, + load_state_dict) + logger = get_logger('mmcv') + if self.prefix is None: + print_log(f'load model from: {self.checkpoint}', logger=logger) + load_checkpoint( + module, + self.checkpoint, + map_location=self.map_location, + strict=False, + logger=logger) + else: + print_log( + f'load {self.prefix} in model from: {self.checkpoint}', + logger=logger) + state_dict = _load_checkpoint_with_prefix( + self.prefix, self.checkpoint, map_location=self.map_location) + load_state_dict(module, state_dict, strict=False, logger=logger) + + if hasattr(module, '_params_init_info'): + update_init_info(module, init_info=self._get_init_info()) + + def _get_init_info(self): + info = f'{self.__class__.__name__}: load from {self.checkpoint}' + return info + + +def _initialize(module, cfg, wholemodule=False): + func = build_from_cfg(cfg, INITIALIZERS) + # wholemodule flag is for override mode, there is no layer key in override + # and initializer will give init values for the whole module with the name + # in override. + func.wholemodule = wholemodule + func(module) + + +def _initialize_override(module, override, cfg): + if not isinstance(override, (dict, list)): + raise TypeError(f'override must be a dict or a list of dict, \ + but got {type(override)}') + + override = [override] if isinstance(override, dict) else override + + for override_ in override: + + cp_override = copy.deepcopy(override_) + name = cp_override.pop('name', None) + if name is None: + raise ValueError('`override` must contain the key "name",' + f'but got {cp_override}') + # if override only has name key, it means use args in init_cfg + if not cp_override: + cp_override.update(cfg) + # if override has name key and other args except type key, it will + # raise error + elif 'type' not in cp_override.keys(): + raise ValueError( + f'`override` need "type" key, but got {cp_override}') + + if hasattr(module, name): + _initialize(getattr(module, name), cp_override, wholemodule=True) + else: + raise RuntimeError(f'module did not have attribute {name}, ' + f'but init_cfg is {cp_override}.') + + +def initialize(module, init_cfg): + """Initialize a module. + + Args: + module (``torch.nn.Module``): the module will be initialized. + init_cfg (dict | list[dict]): initialization configuration dict to + define initializer. OpenMMLab has implemented 6 initializers + including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, + ``Kaiming``, and ``Pretrained``. + Example: + >>> module = nn.Linear(2, 3, bias=True) + >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) + >>> initialize(module, init_cfg) + + >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) + >>> # define key ``'layer'`` for initializing layer with different + >>> # configuration + >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1), + dict(type='Constant', layer='Linear', val=2)] + >>> initialize(module, init_cfg) + + >>> # define key``'override'`` to initialize some specific part in + >>> # module + >>> class FooNet(nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.feat = nn.Conv2d(3, 16, 3) + >>> self.reg = nn.Conv2d(16, 10, 3) + >>> self.cls = nn.Conv2d(16, 5, 3) + >>> model = FooNet() + >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', + >>> override=dict(type='Constant', name='reg', val=3, bias=4)) + >>> initialize(model, init_cfg) + + >>> model = ResNet(depth=50) + >>> # Initialize weights with the pretrained model. + >>> init_cfg = dict(type='Pretrained', + checkpoint='torchvision://resnet50') + >>> initialize(model, init_cfg) + + >>> # Initialize weights of a sub-module with the specific part of + >>> # a pretrained model by using "prefix". + >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\ + >>> 'retinanet_r50_fpn_1x_coco/'\ + >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' + >>> init_cfg = dict(type='Pretrained', + checkpoint=url, prefix='backbone.') + """ + if not isinstance(init_cfg, (dict, list)): + raise TypeError(f'init_cfg must be a dict or a list of dict, \ + but got {type(init_cfg)}') + + if isinstance(init_cfg, dict): + init_cfg = [init_cfg] + + for cfg in init_cfg: + # should deeply copy the original config because cfg may be used by + # other modules, e.g., one init_cfg shared by multiple bottleneck + # blocks, the expected cfg will be changed after pop and will change + # the initialization behavior of other modules + cp_cfg = copy.deepcopy(cfg) + override = cp_cfg.pop('override', None) + _initialize(module, cp_cfg) + + if override is not None: + cp_cfg.pop('layer', None) + _initialize_override(module, override, cp_cfg) + else: + # All attributes in module have same initialization. + pass + + +def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, + b: float) -> Tensor: + # Method based on + # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + # Modified from + # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [lower, upper], then translate + # to [2lower-1, 2upper-1]. + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor: Tensor, + mean: float = 0., + std: float = 1., + a: float = -2., + b: float = 2.) -> Tensor: + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Modified from + https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py + + Args: + tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. + mean (float): the mean of the normal distribution. + std (float): the standard deviation of the normal distribution. + a (float): the minimum cutoff value. + b (float): the maximum cutoff value. + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/annotator/uniformer/mmcv/cnn/vgg.py b/annotator/uniformer/mmcv/cnn/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1 --- /dev/null +++ b/annotator/uniformer/mmcv/cnn/vgg.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.nn as nn + +from .utils import constant_init, kaiming_init, normal_init + + +def conv3x3(in_planes, out_planes, dilation=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + padding=dilation, + dilation=dilation) + + +def make_vgg_layer(inplanes, + planes, + num_blocks, + dilation=1, + with_bn=False, + ceil_mode=False): + layers = [] + for _ in range(num_blocks): + layers.append(conv3x3(inplanes, planes, dilation)) + if with_bn: + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + inplanes = planes + layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) + + return layers + + +class VGG(nn.Module): + """VGG backbone. + + Args: + depth (int): Depth of vgg, from {11, 13, 16, 19}. + with_bn (bool): Use BatchNorm or not. + num_classes (int): number of classes for classification. + num_stages (int): VGG stages, normally 5. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze + running stats (mean and var). + bn_frozen (bool): Whether to freeze weight and bias of BN layers. + """ + + arch_settings = { + 11: (1, 1, 2, 2, 2), + 13: (2, 2, 2, 2, 2), + 16: (2, 2, 3, 3, 3), + 19: (2, 2, 4, 4, 4) + } + + def __init__(self, + depth, + with_bn=False, + num_classes=-1, + num_stages=5, + dilations=(1, 1, 1, 1, 1), + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + bn_eval=True, + bn_frozen=False, + ceil_mode=False, + with_last_pool=True): + super(VGG, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for vgg') + assert num_stages >= 1 and num_stages <= 5 + stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + assert len(dilations) == num_stages + assert max(out_indices) <= num_stages + + self.num_classes = num_classes + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.bn_eval = bn_eval + self.bn_frozen = bn_frozen + + self.inplanes = 3 + start_idx = 0 + vgg_layers = [] + self.range_sub_modules = [] + for i, num_blocks in enumerate(self.stage_blocks): + num_modules = num_blocks * (2 + with_bn) + 1 + end_idx = start_idx + num_modules + dilation = dilations[i] + planes = 64 * 2**i if i < 4 else 512 + vgg_layer = make_vgg_layer( + self.inplanes, + planes, + num_blocks, + dilation=dilation, + with_bn=with_bn, + ceil_mode=ceil_mode) + vgg_layers.extend(vgg_layer) + self.inplanes = planes + self.range_sub_modules.append([start_idx, end_idx]) + start_idx = end_idx + if not with_last_pool: + vgg_layers.pop(-1) + self.range_sub_modules[-1][1] -= 1 + self.module_name = 'features' + self.add_module(self.module_name, nn.Sequential(*vgg_layers)) + + if self.num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = logging.getLogger() + from ..runner import load_checkpoint + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + normal_init(m, std=0.01) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + outs = [] + vgg_layers = getattr(self, self.module_name) + for i in range(len(self.stage_blocks)): + for j in range(*self.range_sub_modules[i]): + vgg_layer = vgg_layers[j] + x = vgg_layer(x) + if i in self.out_indices: + outs.append(x) + if self.num_classes > 0: + x = x.view(x.size(0), -1) + x = self.classifier(x) + outs.append(x) + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def train(self, mode=True): + super(VGG, self).train(mode) + if self.bn_eval: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + if self.bn_frozen: + for params in m.parameters(): + params.requires_grad = False + vgg_layers = getattr(self, self.module_name) + if mode and self.frozen_stages >= 0: + for i in range(self.frozen_stages): + for j in range(*self.range_sub_modules[i]): + mod = vgg_layers[j] + mod.eval() + for param in mod.parameters(): + param.requires_grad = False diff --git a/annotator/uniformer/mmcv/engine/__init__.py b/annotator/uniformer/mmcv/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082 --- /dev/null +++ b/annotator/uniformer/mmcv/engine/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test, + single_gpu_test) + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test' +] diff --git a/annotator/uniformer/mmcv/engine/test.py b/annotator/uniformer/mmcv/engine/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbeef271db634ec2dadfda3bc0b5ef9c7a677ff --- /dev/null +++ b/annotator/uniformer/mmcv/engine/test.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import torch +import torch.distributed as dist + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.runner import get_dist_info + + +def single_gpu_test(model, data_loader): + """Test model with a single gpu. + + This method tests model with a single gpu and displays test progress bar. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for data in data_loader: + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + # Assume result has the same length of batch_size + # refer to https://github.com/open-mmlab/mmcv/issues/985 + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting + ``gpu_collect=True``, it encodes results to gpu tensors and use gpu + communication for results collection. On cpu mode it saves the results on + different gpus to ``tmpdir`` and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + if rank == 0: + batch_size = len(result) + batch_size_all = batch_size * world_size + if batch_size_all + prog_bar.completed > len(dataset): + batch_size_all = len(dataset) - prog_bar.completed + for _ in range(batch_size_all): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + tmpdir (str | None): temporal directory for collected results to + store. If set to None, it will create a random temporal directory + for it. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + mmcv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') + part_result = mmcv.load(part_file) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results under gpu mode. + + On gpu mode, this function will encode results to gpu tensors and use gpu + communication for results collection. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/annotator/uniformer/mmcv/fileio/__init__.py b/annotator/uniformer/mmcv/fileio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .file_client import BaseStorageBackend, FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler +from .io import dump, load, register_handler +from .parse import dict_from_file, list_from_file + +__all__ = [ + 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler', + 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler', + 'list_from_file', 'dict_from_file' +] diff --git a/annotator/uniformer/mmcv/fileio/file_client.py b/annotator/uniformer/mmcv/fileio/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..950f0c1aeab14b8e308a7455ccd64a95b5d98add --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/file_client.py @@ -0,0 +1,1148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import os +import os.path as osp +import re +import tempfile +import warnings +from abc import ABCMeta, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable, Iterator, Optional, Tuple, Union +from urllib.request import urlopen + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.utils.misc import has_method +from annotator.uniformer.mmcv.utils.path import is_filepath + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + _allow_symlink = False + + @property + def name(self): + return self.__class__.__name__ + + @property + def allow_symlink(self): + return self._allow_symlink + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class CephBackend(BaseStorageBackend): + """Ceph storage backend (for internal use). + + Args: + path_mapping (dict|None): path mapping dict from local path to Petrel + path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath`` + will be replaced by ``dst``. Default: None. + + .. warning:: + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. + """ + + def __init__(self, path_mapping=None): + try: + import ceph + except ImportError: + raise ImportError('Please install ceph to enable CephBackend.') + + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') + self._client = ceph.S3Client() + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def get(self, filepath): + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class PetrelBackend(BaseStorageBackend): + """Petrel storage backend (for internal use). + + PetrelBackend supports reading and writing data to multiple clusters. + If the file path contains the cluster name, PetrelBackend will read data + from specified cluster or write data to it. Otherwise, PetrelBackend will + access the default cluster. + + Args: + path_mapping (dict, optional): Path mapping dict from local path to + Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in + ``filepath`` will be replaced by ``dst``. Default: None. + enable_mc (bool, optional): Whether to enable memcached support. + Default: True. + + Examples: + >>> filepath1 = 's3://path/of/file' + >>> filepath2 = 'cluster-name:s3://path/of/file' + >>> client = PetrelBackend() + >>> client.get(filepath1) # get data from default cluster + >>> client.get(filepath2) # get data from 'cluster-name' cluster + """ + + def __init__(self, + path_mapping: Optional[dict] = None, + enable_mc: bool = True): + try: + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + + self._client = client.Client(enable_mc=enable_mc) + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + def _map_path(self, filepath: Union[str, Path]) -> str: + """Map ``filepath`` to a string path whose prefix will be replaced by + :attr:`self.path_mapping`. + + Args: + filepath (str): Path to be mapped. + """ + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + return filepath + + def _format_path(self, filepath: str) -> str: + """Convert a ``filepath`` to standard format of petrel oss. + + If the ``filepath`` is concatenated by ``os.path.join``, in a Windows + environment, the ``filepath`` will be the format of + 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the + above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. + + Args: + filepath (str): Path to be formatted. + """ + return re.sub(r'\\+', '/', filepath) + + def get(self, filepath: Union[str, Path]) -> memoryview: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + memoryview: A memory view of expected bytes object to avoid + copying. The memoryview object can be converted to bytes by + ``value_buf.tobytes()``. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + value = self._client.Get(filepath) + value_buf = memoryview(value) + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return str(self.get(filepath), encoding=encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Save data to a given ``filepath``. + + Args: + obj (bytes): Data to be saved. + filepath (str or Path): Path to write data. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.put(filepath, obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Save data to a given ``filepath``. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to encode the ``obj``. + Default: 'utf-8'. + """ + self.put(bytes(obj, encoding=encoding), filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + if not has_method(self._client, 'delete'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `delete` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + self._client.delete(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + if not (has_method(self._client, 'contains') + and has_method(self._client, 'isdir')): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` and `isdir` methods, please use a higher' + 'version or dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) or self._client.isdir(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + if not has_method(self._client, 'isdir'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `isdir` method, please use a higher version or dev' + ' branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + if not has_method(self._client, 'contains'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `contains` method, please use a higher version or ' + 'dev branch instead.')) + + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + return self._client.contains(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result after concatenation. + """ + filepath = self._format_path(self._map_path(filepath)) + if filepath.endswith('/'): + filepath = filepath[:-1] + formatted_paths = [filepath] + for path in filepaths: + formatted_paths.append(self._format_path(self._map_path(path))) + return '/'.join(formatted_paths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download a file from ``filepath`` and return a temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str | Path): Download a file from ``filepath``. + + Examples: + >>> client = PetrelBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('s3://path/of/your/file') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one temporary path. + """ + filepath = self._map_path(filepath) + filepath = self._format_path(filepath) + assert self.isfile(filepath) + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + Petrel has no concept of directories but it simulates the directory + hierarchy in the filesystem through public prefixes. In addition, + if the returned path ends with '/', it means the path is a public + prefix which is a logical directory. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + In addition, the returned path of directory will not contains the + suffix '/' which is consistent with other backends. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if not has_method(self._client, 'list'): + raise NotImplementedError( + ('Current version of Petrel Python SDK has not supported ' + 'the `list` method, please use a higher version or dev' + ' branch instead.')) + + dir_path = self._map_path(dir_path) + dir_path = self._format_path(dir_path) + if list_dir and suffix is not None: + raise TypeError( + '`list_dir` should be False when `suffix` is not None') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + # Petrel's simulated directory hierarchy assumes that directory paths + # should end with `/` + if not dir_path.endswith('/'): + dir_path += '/' + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for path in self._client.list(dir_path): + # the `self.isdir` is not used here to determine whether path + # is a directory, because `self.isdir` relies on + # `self._client.list` + if path.endswith('/'): # a directory path + next_dir_path = self.join_path(dir_path, path) + if list_dir: + # get the relative path and exclude the last + # character '/' + rel_dir = next_dir_path[len(root):-1] + yield rel_dir + if recursive: + yield from _list_dir_or_file(next_dir_path, list_dir, + list_file, suffix, + recursive) + else: # a file path + absolute_path = self.join_path(dir_path, path) + rel_path = absolute_path[len(root):] + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_path (str): Lmdb database path. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_path (str): Lmdb database path. + """ + + def __init__(self, + db_path, + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + self.db_path = str(db_path) + self._client = lmdb.open( + self.db_path, + readonly=readonly, + lock=lock, + readahead=readahead, + **kwargs) + + def get(self, filepath): + """Get values according to the filepath. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + """ + filepath = str(filepath) + with self._client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath, encoding=None): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, 'r', encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'wb') as f: + f.write(obj) + + def put_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + mmcv.mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + """ + os.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: + """Only for unified API and do nothing.""" + yield filepath + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + if list_dir and suffix is not None: + raise TypeError('`suffix` should be None when `list_dir` is True') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('`suffix` must be a string or tuple of strings') + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None + or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, + list_file, suffix, + recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, + recursive) + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage bachend.""" + + def get(self, filepath): + value_buf = urlopen(filepath).read() + return value_buf + + def get_text(self, filepath, encoding='utf-8'): + value_buf = urlopen(filepath).read() + return value_buf.decode(encoding) + + @contextmanager + def get_local_path(self, filepath: str) -> Iterable[str]: + """Download a file from ``filepath``. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> client = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with client.get_local_path('http://path/of/your/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + +class FileClient: + """A general file client to access files in different backends. + + The client loads a file or text in a specified backend from its path + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object will be returned. + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "ceph", "memcached", "lmdb", "http" and "petrel". Default: None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Default: None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='petrel') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='petrel', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='petrel') + >>> file_client1 is file_client + True + + Attributes: + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'ceph': CephBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + 'petrel': PetrelBackend, + 'http': HTTPBackend, + } + # This collection is used to record the overridden backends, and when a + # backend appears in the collection, the singleton pattern is disabled for + # that backend, because if the singleton pattern is used, then the object + # returned will be the backend before overwriting + _overridden_backends = set() + _prefix_to_backends = { + 's3': PetrelBackend, + 'http': HTTPBackend, + 'https': HTTPBackend, + } + _overridden_prefixes = set() + + _instances = {} + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = 'disk' + if backend is not None and backend not in cls._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones' + f' are {list(cls._backends.keys())}') + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f'prefix {prefix} is not supported. Currently supported ones ' + f'are {list(cls._prefix_to_backends.keys())}') + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f'{backend}:{prefix}' + for key, value in kwargs.items(): + arg_key += f':{key}:{value}' + + # if a backend was overridden, it will create a new object + if (arg_key in cls._instances + and backend not in cls._overridden_backends + and prefix not in cls._overridden_prefixes): + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + + cls._instances[arg_key] = _instance + + return _instance + + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' + else ``None``. + """ + assert is_filepath(uri) + uri = str(uri) + if '://' not in uri: + return None + else: + prefix, _ = uri.split('://') + # In the case of PetrelBackend, the prefix may contains the cluster + # name like clusterName:s3 + if ':' in prefix: + _, prefix = prefix.split(':') + return prefix + + @classmethod + def infer_client(cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None) -> 'FileClient': + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Default: None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 'petrel'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') + if not inspect.isclass(backend): + raise TypeError( + f'backend should be a class but got {type(backend)}') + if not issubclass(backend, BaseStorageBackend): + raise TypeError( + f'backend {backend} is not a subclass of BaseStorageBackend') + if not force and name in cls._backends: + raise KeyError( + f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') + + if name in cls._backends and force: + cls._overridden_backends.add(name) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + cls._overridden_prefixes.add(prefix) + cls._prefix_to_backends[prefix] = backend + else: + raise KeyError( + f'{prefix} is already registered as a storage backend,' + ' add "force=True" if you want to override it') + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Default: None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend( + name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend( + name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Default: 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], + *filepaths: Union[str, Path]) -> str: + """Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of *filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file(self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Default: True. + list_file (bool): List the path of files. Default: True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Default: None. + recursive (bool): If set to True, recursively scan the + directory. Default: False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, + suffix, recursive) diff --git a/annotator/uniformer/mmcv/fileio/handlers/__init__.py b/annotator/uniformer/mmcv/fileio/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseFileHandler +from .json_handler import JsonHandler +from .pickle_handler import PickleHandler +from .yaml_handler import YamlHandler + +__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler'] diff --git a/annotator/uniformer/mmcv/fileio/handlers/base.py b/annotator/uniformer/mmcv/fileio/handlers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/base.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode='r', **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode='w', **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/json_handler.py b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + +import numpy as np + +from .base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f'{type(obj)} is unsupported for json dump') + + +class JsonHandler(BaseFileHandler): + + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('default', set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('default', set_default) + return json.dumps(obj, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pickle + +from .base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super(PickleHandler, self).load_from_path( + filepath, mode='rb', **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('protocol', 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('protocol', 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + super(PickleHandler, self).dump_to_path( + obj, filepath, mode='wb', **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import yaml + +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + +from .base import BaseFileHandler # isort:skip + + +class YamlHandler(BaseFileHandler): + + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault('Loader', Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('Dumper', Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('Dumper', Dumper) + return yaml.dump(obj, **kwargs) diff --git a/annotator/uniformer/mmcv/fileio/io.py b/annotator/uniformer/mmcv/fileio/io.py new file mode 100644 index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/io.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO +from pathlib import Path + +from ..utils import is_list_of, is_str +from .file_client import FileClient +from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler + +file_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), + 'pickle': PickleHandler(), + 'pkl': PickleHandler() +} + + +def load(file, file_format=None, file_client_args=None, **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Note: + In v1.3.16 and later, ``load`` supports loading data from serialized + files those can be storaged in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in petrel + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and is_str(file): + file_format = file.split('.')[-1] + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = file_handlers[file_format] + if is_str(file): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO(file_client.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_client.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, 'read'): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + Note: + In v1.3.16 and later, ``dump`` supports dumping data as strings or to + files which is saved to different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if is_str(file): + file_format = file.split('.')[-1] + elif file is None: + raise ValueError( + 'file_format must be specified since file is None') + if file_format not in file_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif is_str(file): + file_client = FileClient.infer_client(file_client_args, file) + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_client.put(f.getvalue(), file) + elif hasattr(file, 'write'): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError( + f'handler must be a child of BaseFileHandler, not {type(handler)}') + if isinstance(file_formats, str): + file_formats = [file_formats] + if not is_list_of(file_formats, str): + raise TypeError('file_formats must be a str or a list of str') + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/annotator/uniformer/mmcv/fileio/parse.py b/annotator/uniformer/mmcv/fileio/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9 --- /dev/null +++ b/annotator/uniformer/mmcv/fileio/parse.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from io import StringIO + +from .file_client import FileClient + + +def list_from_file(filename, + prefix='', + offset=0, + max_num=0, + encoding='utf-8', + file_client_args=None): + """Load a text file and parse the content as a list of strings. + + Note: + In v1.3.16 and later, ``list_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a list for strings. + + Args: + filename (str): Filename. + prefix (str): The prefix to be inserted to the beginning of each item. + offset (int): The offset of lines. + max_num (int): The maximum number of lines to be read, + zeros and negatives mean no limitation. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> list_from_file('/path/of/your/file') # disk + ['hello', 'world'] + >>> list_from_file('s3://path/of/your/file') # ceph or petrel + ['hello', 'world'] + + Returns: + list[str]: A list of strings. + """ + cnt = 0 + item_list = [] + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for _ in range(offset): + f.readline() + for line in f: + if 0 < max_num <= cnt: + break + item_list.append(prefix + line.rstrip('\n\r')) + cnt += 1 + return item_list + + +def dict_from_file(filename, + key_type=str, + encoding='utf-8', + file_client_args=None): + """Load a text file and parse the content as a dict. + + Each line of the text file will be two or more columns split by + whitespaces or tabs. The first column will be parsed as dict keys, and + the following columns will be parsed as dict values. + + Note: + In v1.3.16 and later, ``dict_from_file`` supports loading a text file + which can be storaged in different backends and parsing the content as + a dict. + + Args: + filename(str): Filename. + key_type(type): Type of the dict keys. str is user by default and + type conversion will be performed if specified. + encoding (str): Encoding used to open the file. Default utf-8. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + + Examples: + >>> dict_from_file('/path/of/your/file') # disk + {'key1': 'value1', 'key2': 'value2'} + >>> dict_from_file('s3://path/of/your/file') # ceph or petrel + {'key1': 'value1', 'key2': 'value2'} + + Returns: + dict: The parsed contents. + """ + mapping = {} + file_client = FileClient.infer_client(file_client_args, filename) + with StringIO(file_client.get_text(filename, encoding)) as f: + for line in f: + items = line.rstrip('\n').split() + assert len(items) >= 2 + key = key_type(items[0]) + val = items[1:] if len(items) > 2 else items[1] + mapping[key] = val + return mapping diff --git a/annotator/uniformer/mmcv/image/__init__.py b/annotator/uniformer/mmcv/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91 --- /dev/null +++ b/annotator/uniformer/mmcv/image/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr, + gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert, + rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb) +from .geometric import (cutout, imcrop, imflip, imflip_, impad, + impad_to_multiple, imrescale, imresize, imresize_like, + imresize_to_multiple, imrotate, imshear, imtranslate, + rescale_size) +from .io import imfrombytes, imread, imwrite, supported_backends, use_backend +from .misc import tensor2imgs +from .photometric import (adjust_brightness, adjust_color, adjust_contrast, + adjust_lighting, adjust_sharpness, auto_contrast, + clahe, imdenormalize, imequalize, iminvert, + imnormalize, imnormalize_, lut_transform, posterize, + solarize) + +__all__ = [ + 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', + 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale', + 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size', + 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate', + 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend', + 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize', + 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', + 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize', + 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe', + 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting' +] diff --git a/annotator/uniformer/mmcv/image/colorspace.py b/annotator/uniformer/mmcv/image/colorspace.py new file mode 100644 index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add --- /dev/null +++ b/annotator/uniformer/mmcv/image/colorspace.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + + +def imconvert(img, src, dst): + """Convert an image from the src colorspace to dst colorspace. + + Args: + img (ndarray): The input image. + src (str): The source colorspace, e.g., 'rgb', 'hsv'. + dst (str): The destination colorspace, e.g., 'rgb', 'hsv'. + + Returns: + ndarray: The converted image. + """ + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + out_img = cv2.cvtColor(img, code) + return out_img + + +def bgr2gray(img, keepdim=False): + """Convert a BGR image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. + """ + out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def rgb2gray(img, keepdim=False): + """Convert a RGB image to grayscale image. + + Args: + img (ndarray): The input image. + keepdim (bool): If False (by default), then return the grayscale image + with 2 dims, otherwise 3 dims. + + Returns: + ndarray: The converted grayscale image. + """ + out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + if keepdim: + out_img = out_img[..., None] + return out_img + + +def gray2bgr(img): + """Convert a grayscale image to BGR image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted BGR image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + return out_img + + +def gray2rgb(img): + """Convert a grayscale image to RGB image. + + Args: + img (ndarray): The input image. + + Returns: + ndarray: The converted RGB image. + """ + img = img[..., None] if img.ndim == 2 else img + out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def convert_color_factory(src, dst): + + code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') + + def convert_color(img): + out_img = cv2.cvtColor(img, code) + return out_img + + convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()} + image. + + Args: + img (ndarray or str): The input image. + + Returns: + ndarray: The converted {dst.upper()} image. + """ + + return convert_color + + +bgr2rgb = convert_color_factory('bgr', 'rgb') + +rgb2bgr = convert_color_factory('rgb', 'bgr') + +bgr2hsv = convert_color_factory('bgr', 'hsv') + +hsv2bgr = convert_color_factory('hsv', 'bgr') + +bgr2hls = convert_color_factory('bgr', 'hls') + +hls2bgr = convert_color_factory('hls', 'bgr') diff --git a/annotator/uniformer/mmcv/image/geometric.py b/annotator/uniformer/mmcv/image/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d --- /dev/null +++ b/annotator/uniformer/mmcv/image/geometric.py @@ -0,0 +1,728 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers + +import cv2 +import numpy as np + +from ..utils import to_2tuple +from .io import imread_backend + +try: + from PIL import Image +except ImportError: + Image = None + + +def _scale_size(size, scale): + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. + """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) + + +cv2_interp_codes = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'bicubic': cv2.INTER_CUBIC, + 'area': cv2.INTER_AREA, + 'lanczos': cv2.INTER_LANCZOS4 +} + +if Image is not None: + pillow_interp_codes = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'lanczos': Image.LANCZOS, + 'hamming': Image.HAMMING + } + + +def imresize(img, + size, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend is None: + backend = imread_backend + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + f"Supported backends are 'cv2', 'pillow'") + + if backend == 'pillow': + assert img.dtype == np.uint8, 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize( + img, size, dst=out, interpolation=cv2_interp_codes[interpolation]) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def imresize_to_multiple(img, + divisor, + size=None, + scale_factor=None, + keep_ratio=False, + return_scale=False, + interpolation='bilinear', + out=None, + backend=None): + """Resize image according to a given size or scale factor and then rounds + up the the resized or rescaled image size to the nearest value that can be + divided by the divisor. + + Args: + img (ndarray): The input image. + divisor (int | tuple): Resized image size will be a multiple of + divisor. If divisor is a tuple, divisor should be + (w_divisor, h_divisor). + size (None | int | tuple[int]): Target size (w, h). Default: None. + scale_factor (None | float | tuple[float]): Multiplier for spatial + size. Should match input size if it is a tuple and the 2D style is + (w_scale_factor, h_scale_factor). Default: None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Default: False. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if size is not None and scale_factor is not None: + raise ValueError('only one of size or scale_factor should be defined') + elif size is None and scale_factor is None: + raise ValueError('one of size or scale_factor should be defined') + elif size is not None: + size = to_2tuple(size) + if keep_ratio: + size = rescale_size((w, h), size, return_scale=False) + else: + size = _scale_size((w, h), scale_factor) + + divisor = to_2tuple(divisor) + size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)]) + resized_img, w_scale, h_scale = imresize( + img, + size, + return_scale=True, + interpolation=interpolation, + out=out, + backend=backend) + if return_scale: + return resized_img, w_scale, h_scale + else: + return resized_img + + +def imresize_like(img, + dst_img, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image to the same size of a given image. + + Args: + img (ndarray): The input image. + dst_img (ndarray): The target image. + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = dst_img.shape[:2] + return imresize(img, (w, h), return_scale, interpolation, backend=backend) + + +def rescale_size(old_size, scale, return_scale=False): + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError( + f'Scale must be a number or tuple of int, but got {type(scale)}') + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def imrescale(img, + scale, + return_scale=False, + interpolation='bilinear', + backend=None): + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize( + img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +def imflip(img, direction='horizontal'): + """Flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image. + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return np.flip(img, axis=1) + elif direction == 'vertical': + return np.flip(img, axis=0) + else: + return np.flip(img, axis=(0, 1)) + + +def imflip_(img, direction='horizontal'): + """Inplace flip an image horizontally or vertically. + + Args: + img (ndarray): Image to be flipped. + direction (str): The flip direction, either "horizontal" or + "vertical" or "diagonal". + + Returns: + ndarray: The flipped image (inplace). + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + if direction == 'horizontal': + return cv2.flip(img, 1, img) + elif direction == 'vertical': + return cv2.flip(img, 0, img) + else: + return cv2.flip(img, -1, img) + + +def imrotate(img, + angle, + center=None, + scale=1.0, + border_value=0, + interpolation='bilinear', + auto_bound=False): + """Rotate an image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees, positive values mean + clockwise rotation. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. + scale (float): Isotropic scale factor. + border_value (int): Border value. + interpolation (str): Same as :func:`resize`. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. + + Returns: + ndarray: The rotated image. + """ + if center is not None and auto_bound: + raise ValueError('`auto_bound` conflicts with `center`') + h, w = img.shape[:2] + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + assert isinstance(center, tuple) + + matrix = cv2.getRotationMatrix2D(center, -angle, scale) + if auto_bound: + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine( + img, + matrix, (w, h), + flags=cv2_interp_codes[interpolation], + borderValue=border_value) + return rotated + + +def bbox_clip(bboxes, img_shape): + """Clip bboxes to fit the image shape. + + Args: + bboxes (ndarray): Shape (..., 4*k) + img_shape (tuple[int]): (height, width) of the image. + + Returns: + ndarray: Clipped bboxes. + """ + assert bboxes.shape[-1] % 4 == 0 + cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype) + cmin[0::2] = img_shape[1] - 1 + cmin[1::2] = img_shape[0] - 1 + clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0) + return clipped_bboxes + + +def bbox_scaling(bboxes, scale, clip_shape=None): + """Scaling bboxes w.r.t the box center. + + Args: + bboxes (ndarray): Shape(..., 4). + scale (float): Scaling factor. + clip_shape (tuple[int], optional): If specified, bboxes that exceed the + boundary will be clipped according to the given shape (h, w). + + Returns: + ndarray: Scaled bboxes. + """ + if float(scale) == 1.0: + scaled_bboxes = bboxes.copy() + else: + w = bboxes[..., 2] - bboxes[..., 0] + 1 + h = bboxes[..., 3] - bboxes[..., 1] + 1 + dw = (w * (scale - 1)) * 0.5 + dh = (h * (scale - 1)) * 0.5 + scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1) + if clip_shape is not None: + return bbox_clip(scaled_bboxes, clip_shape) + else: + return scaled_bboxes + + +def imcrop(img, bboxes, scale=1.0, pad_fill=None): + """Crop image patches. + + 3 steps: scale the bboxes -> clip bboxes -> crop and pad. + + Args: + img (ndarray): Image to be cropped. + bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. + scale (float, optional): Scale ratio of bboxes, the default value + 1.0 means no padding. + pad_fill (Number | list[Number]): Value to be filled for padding. + Default: None, which means no padding. + + Returns: + list[ndarray] | ndarray: The cropped image patches. + """ + chn = 1 if img.ndim == 2 else img.shape[2] + if pad_fill is not None: + if isinstance(pad_fill, (int, float)): + pad_fill = [pad_fill for _ in range(chn)] + assert len(pad_fill) == chn + + _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes + scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32) + clipped_bbox = bbox_clip(scaled_bboxes, img.shape) + + patches = [] + for i in range(clipped_bbox.shape[0]): + x1, y1, x2, y2 = tuple(clipped_bbox[i, :]) + if pad_fill is None: + patch = img[y1:y2 + 1, x1:x2 + 1, ...] + else: + _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :]) + if chn == 1: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1) + else: + patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn) + patch = np.array( + pad_fill, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + x_start = 0 if _x1 >= 0 else -_x1 + y_start = 0 if _y1 >= 0 else -_y1 + w = x2 - x1 + 1 + h = y2 - y1 + 1 + patch[y_start:y_start + h, x_start:x_start + w, + ...] = img[y1:y1 + h, x1:x1 + w, ...] + patches.append(patch) + + if bboxes.ndim == 1: + return patches[0] + else: + return patches + + +def impad(img, + *, + shape=None, + padding=None, + pad_val=0, + padding_mode='constant'): + """Pad the given image to a certain shape or pad on all sides with + specified padding mode and padding value. + + Args: + img (ndarray): Image to be padded. + shape (tuple[int]): Expected padding shape (h, w). Default: None. + padding (int or tuple[int]): Padding on each border. If a single int is + provided this is used to pad all borders. If tuple of length 2 is + provided this is the padding on left/right and top/bottom + respectively. If a tuple of length 4 is provided this is the + padding for the left, top, right and bottom borders respectively. + Default: None. Note that `shape` and `padding` can not be both + set. + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas when padding_mode is 'constant'. Default: 0. + padding_mode (str): Type of padding. Should be: constant, edge, + reflect or symmetric. Default: constant. + + - constant: pads with a constant value, this value is specified + with pad_val. + - edge: pads with the last value at the edge of the image. + - reflect: pads with reflection of image without repeating the + last value on the edge. For example, padding [1, 2, 3, 4] + with 2 elements on both sides in reflect mode will result + in [3, 2, 1, 2, 3, 4, 3, 2]. + - symmetric: pads with reflection of image repeating the last + value on the edge. For example, padding [1, 2, 3, 4] with + 2 elements on both sides in symmetric mode will result in + [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + ndarray: The padded image. + """ + + assert (shape is not None) ^ (padding is not None) + if shape is not None: + padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0]) + + # check pad_val + if isinstance(pad_val, tuple): + assert len(pad_val) == img.shape[-1] + elif not isinstance(pad_val, numbers.Number): + raise TypeError('pad_val must be a int or a tuple. ' + f'But received {type(pad_val)}') + + # check padding + if isinstance(padding, tuple) and len(padding) in [2, 4]: + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + elif isinstance(padding, numbers.Number): + padding = (padding, padding, padding, padding) + else: + raise ValueError('Padding must be a int or a 2, or 4 element tuple.' + f'But received {padding}') + + # check padding mode + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + + border_type = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT + } + img = cv2.copyMakeBorder( + img, + padding[1], + padding[3], + padding[0], + padding[2], + border_type[padding_mode], + value=pad_val) + + return img + + +def impad_to_multiple(img, divisor, pad_val=0): + """Pad an image to ensure each edge to be multiple to some number. + + Args: + img (ndarray): Image to be padded. + divisor (int): Padded image edges will be multiple to divisor. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. + + Returns: + ndarray: The padded image. + """ + pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor + pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor + return impad(img, shape=(pad_h, pad_w), pad_val=pad_val) + + +def cutout(img, shape, pad_val=0): + """Randomly cut out a rectangle from the original img. + + Args: + img (ndarray): Image to be cutout. + shape (int | tuple[int]): Expected cutout shape (h, w). If given as a + int, the value will be used for both h and w. + pad_val (int | float | tuple[int | float]): Values to be filled in the + cut area. Defaults to 0. + + Returns: + ndarray: The cutout image. + """ + + channels = 1 if img.ndim == 2 else img.shape[2] + if isinstance(shape, int): + cut_h, cut_w = shape, shape + else: + assert isinstance(shape, tuple) and len(shape) == 2, \ + f'shape must be a int or a tuple with length 2, but got type ' \ + f'{type(shape)} instead.' + cut_h, cut_w = shape + if isinstance(pad_val, (int, float)): + pad_val = tuple([pad_val] * channels) + elif isinstance(pad_val, tuple): + assert len(pad_val) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(pad_val), channels) + else: + raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`') + + img_h, img_w = img.shape[:2] + y0 = np.random.uniform(img_h) + x0 = np.random.uniform(img_w) + + y1 = int(max(0, y0 - cut_h / 2.)) + x1 = int(max(0, x0 - cut_w / 2.)) + y2 = min(img_h, y1 + cut_h) + x2 = min(img_w, x1 + cut_w) + + if img.ndim == 2: + patch_shape = (y2 - y1, x2 - x1) + else: + patch_shape = (y2 - y1, x2 - x1, channels) + + img_cutout = img.copy() + patch = np.array( + pad_val, dtype=img.dtype) * np.ones( + patch_shape, dtype=img.dtype) + img_cutout[y1:y2, x1:x2, ...] = patch + + return img_cutout + + +def _get_shear_matrix(magnitude, direction='horizontal'): + """Generate the shear matrix for transformation. + + Args: + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + + Returns: + ndarray: The shear matrix with dtype float32. + """ + if direction == 'horizontal': + shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]]) + elif direction == 'vertical': + shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]]) + return shear_matrix + + +def imshear(img, + magnitude, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Shear an image. + + Args: + img (ndarray): Image to be sheared with format (h, w) + or (h, w, c). + magnitude (int | float): The magnitude used for shear. + direction (str): The flip direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The sheared image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`') + shear_matrix = _get_shear_matrix(magnitude, direction) + sheared = cv2.warpAffine( + img, + shear_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. shearing masks whose channels large + # than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return sheared + + +def _get_translate_matrix(offset, direction='horizontal'): + """Generate the translate matrix. + + Args: + offset (int | float): The offset used for translate. + direction (str): The translate direction, either + "horizontal" or "vertical". + + Returns: + ndarray: The translate matrix with dtype float32. + """ + if direction == 'horizontal': + translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]]) + elif direction == 'vertical': + translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]]) + return translate_matrix + + +def imtranslate(img, + offset, + direction='horizontal', + border_value=0, + interpolation='bilinear'): + """Translate an image. + + Args: + img (ndarray): Image to be translated with format + (h, w) or (h, w, c). + offset (int | float): The offset used for translate. + direction (str): The translate direction, either "horizontal" + or "vertical". + border_value (int | tuple[int]): Value used in case of a + constant border. + interpolation (str): Same as :func:`resize`. + + Returns: + ndarray: The translated image. + """ + assert direction in ['horizontal', + 'vertical'], f'Invalid direction: {direction}' + height, width = img.shape[:2] + if img.ndim == 2: + channels = 1 + elif img.ndim == 3: + channels = img.shape[-1] + if isinstance(border_value, int): + border_value = tuple([border_value] * channels) + elif isinstance(border_value, tuple): + assert len(border_value) == channels, \ + 'Expected the num of elements in tuple equals the channels' \ + 'of input image. Found {} vs {}'.format( + len(border_value), channels) + else: + raise ValueError( + f'Invalid type {type(border_value)} for `border_value`.') + translate_matrix = _get_translate_matrix(offset, direction) + translated = cv2.warpAffine( + img, + translate_matrix, + (width, height), + # Note case when the number elements in `border_value` + # greater than 3 (e.g. translating masks whose channels + # large than 3) will raise TypeError in `cv2.warpAffine`. + # Here simply slice the first 3 values in `border_value`. + borderValue=border_value[:3], + flags=cv2_interp_codes[interpolation]) + return translated diff --git a/annotator/uniformer/mmcv/image/io.py b/annotator/uniformer/mmcv/image/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fa2e8cc06b1a7b0b69de6406980b15d61a1e5d --- /dev/null +++ b/annotator/uniformer/mmcv/image/io.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import os.path as osp +from pathlib import Path + +import cv2 +import numpy as np +from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION, + IMREAD_UNCHANGED) + +from annotator.uniformer.mmcv.utils import check_file_exist, is_str, mkdir_or_exist + +try: + from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG +except ImportError: + TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None + +try: + from PIL import Image, ImageOps +except ImportError: + Image = None + +try: + import tifffile +except ImportError: + tifffile = None + +jpeg = None +supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile'] + +imread_flags = { + 'color': IMREAD_COLOR, + 'grayscale': IMREAD_GRAYSCALE, + 'unchanged': IMREAD_UNCHANGED, + 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR, + 'grayscale_ignore_orientation': + IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE +} + +imread_backend = 'cv2' + + +def use_backend(backend): + """Select a backend for image decoding. + + Args: + backend (str): The image decoding backend type. Options are `cv2`, + `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG) + and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg` + file format. + """ + assert backend in supported_backends + global imread_backend + imread_backend = backend + if imread_backend == 'turbojpeg': + if TurboJPEG is None: + raise ImportError('`PyTurboJPEG` is not installed') + global jpeg + if jpeg is None: + jpeg = TurboJPEG() + elif imread_backend == 'pillow': + if Image is None: + raise ImportError('`Pillow` is not installed') + elif imread_backend == 'tifffile': + if tifffile is None: + raise ImportError('`tifffile` is not installed') + + +def _jpegflag(flag='color', channel_order='bgr'): + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'color': + if channel_order == 'bgr': + return TJPF_BGR + elif channel_order == 'rgb': + return TJCS_RGB + elif flag == 'grayscale': + return TJPF_GRAY + else: + raise ValueError('flag must be "color" or "grayscale"') + + +def _pillow2array(img, flag='color', channel_order='bgr'): + """Convert a pillow image to numpy array. + + Args: + img (:obj:`PIL.Image.Image`): The image loaded using PIL + flag (str): Flags specifying the color type of a loaded image, + candidates are 'color', 'grayscale' and 'unchanged'. + Default to 'color'. + channel_order (str): The channel order of the output image array, + candidates are 'bgr' and 'rgb'. Default to 'bgr'. + + Returns: + np.ndarray: The converted numpy array + """ + channel_order = channel_order.lower() + if channel_order not in ['rgb', 'bgr']: + raise ValueError('channel order must be either "rgb" or "bgr"') + + if flag == 'unchanged': + array = np.array(img) + if array.ndim >= 3 and array.shape[2] >= 3: # color image + array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR + else: + # Handle exif orientation tag + if flag in ['color', 'grayscale']: + img = ImageOps.exif_transpose(img) + # If the image mode is not 'RGB', convert it to 'RGB' first. + if img.mode != 'RGB': + if img.mode != 'LA': + # Most formats except 'LA' can be directly converted to RGB + img = img.convert('RGB') + else: + # When the mode is 'LA', the default conversion will fill in + # the canvas with black, which sometimes shadows black objects + # in the foreground. + # + # Therefore, a random color (124, 117, 104) is used for canvas + img_rgba = img.convert('RGBA') + img = Image.new('RGB', img_rgba.size, (124, 117, 104)) + img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha + if flag in ['color', 'color_ignore_orientation']: + array = np.array(img) + if channel_order != 'rgb': + array = array[:, :, ::-1] # RGB to BGR + elif flag in ['grayscale', 'grayscale_ignore_orientation']: + img = img.convert('L') + array = np.array(img) + else: + raise ValueError( + 'flag must be "color", "grayscale", "unchanged", ' + f'"color_ignore_orientation" or "grayscale_ignore_orientation"' + f' but got {flag}') + return array + + +def imread(img_or_path, flag='color', channel_order='bgr', backend=None): + """Read an image. + + Args: + img_or_path (ndarray or str or Path): Either a numpy array or str or + pathlib.Path. If it is a numpy array (loaded image), then + it will be returned as is. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale`, `unchanged`, + `color_ignore_orientation` and `grayscale_ignore_orientation`. + By default, `cv2` and `pillow` backend would rotate the image + according to its EXIF info unless called with `unchanged` or + `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend + always ignore image's EXIF info regardless of the flag. + The `turbojpeg` backend only supports `color` and `grayscale`. + channel_order (str): Order of channel, candidates are `bgr` and `rgb`. + backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. + If backend is None, the global imread_backend specified by + ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + ndarray: Loaded image array. + """ + + if backend is None: + backend = imread_backend + if backend not in supported_backends: + raise ValueError(f'backend: {backend} is not supported. Supported ' + "backends are 'cv2', 'turbojpeg', 'pillow'") + if isinstance(img_or_path, Path): + img_or_path = str(img_or_path) + + if isinstance(img_or_path, np.ndarray): + return img_or_path + elif is_str(img_or_path): + check_file_exist(img_or_path, + f'img file does not exist: {img_or_path}') + if backend == 'turbojpeg': + with open(img_or_path, 'rb') as in_file: + img = jpeg.decode(in_file.read(), + _jpegflag(flag, channel_order)) + if img.shape[-1] == 1: + img = img[:, :, 0] + return img + elif backend == 'pillow': + img = Image.open(img_or_path) + img = _pillow2array(img, flag, channel_order) + return img + elif backend == 'tifffile': + img = tifffile.imread(img_or_path) + return img + else: + flag = imread_flags[flag] if is_str(flag) else flag + img = cv2.imread(img_or_path, flag) + if flag == IMREAD_COLOR and channel_order == 'rgb': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + else: + raise TypeError('"img" must be a numpy array or a str or ' + 'a pathlib.Path object') + + +def imfrombytes(content, flag='color', channel_order='bgr', backend=None): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Same as :func:`imread`. + backend (str | None): The image decoding backend type. Options are + `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the + global imread_backend specified by ``mmcv.use_backend()`` will be + used. Default: None. + + Returns: + ndarray: Loaded image array. + """ + + if backend is None: + backend = imread_backend + if backend not in supported_backends: + raise ValueError(f'backend: {backend} is not supported. Supported ' + "backends are 'cv2', 'turbojpeg', 'pillow'") + if backend == 'turbojpeg': + img = jpeg.decode(content, _jpegflag(flag, channel_order)) + if img.shape[-1] == 1: + img = img[:, :, 0] + return img + elif backend == 'pillow': + buff = io.BytesIO(content) + img = Image.open(buff) + img = _pillow2array(img, flag, channel_order) + return img + else: + img_np = np.frombuffer(content, np.uint8) + flag = imread_flags[flag] if is_str(flag) else flag + img = cv2.imdecode(img_np, flag) + if flag == IMREAD_COLOR and channel_order == 'rgb': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = osp.abspath(osp.dirname(file_path)) + mkdir_or_exist(dir_name) + return cv2.imwrite(file_path, img, params) diff --git a/annotator/uniformer/mmcv/image/misc.py b/annotator/uniformer/mmcv/image/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3e61f05e3b05e4c7b40de4eb6c8eb100e6da41d0 --- /dev/null +++ b/annotator/uniformer/mmcv/image/misc.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +import annotator.uniformer.mmcv as mmcv + +try: + import torch +except ImportError: + torch = None + + +def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): + """Convert tensor to 3-channel images. + + Args: + tensor (torch.Tensor): Tensor that contains multiple images, shape ( + N, C, H, W). + mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0). + std (tuple[float], optional): Standard deviation of images. + Defaults to (1, 1, 1). + to_rgb (bool, optional): Whether the tensor was converted to RGB + format in the first place. If so, convert it back to BGR. + Defaults to True. + + Returns: + list[np.ndarray]: A list that contains multiple images. + """ + + if torch is None: + raise RuntimeError('pytorch is not installed') + assert torch.is_tensor(tensor) and tensor.ndim == 4 + assert len(mean) == 3 + assert len(std) == 3 + + num_imgs = tensor.size(0) + mean = np.array(mean, dtype=np.float32) + std = np.array(std, dtype=np.float32) + imgs = [] + for img_id in range(num_imgs): + img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) + img = mmcv.imdenormalize( + img, mean, std, to_bgr=to_rgb).astype(np.uint8) + imgs.append(np.ascontiguousarray(img)) + return imgs diff --git a/annotator/uniformer/mmcv/image/photometric.py b/annotator/uniformer/mmcv/image/photometric.py new file mode 100644 index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae --- /dev/null +++ b/annotator/uniformer/mmcv/image/photometric.py @@ -0,0 +1,428 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from ..utils import is_tuple_of +from .colorspace import bgr2gray, gray2bgr + + +def imnormalize(img, mean, std, to_rgb=True): + """Normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. + """ + img = img.copy().astype(np.float32) + return imnormalize_(img, mean, std, to_rgb) + + +def imnormalize_(img, mean, std, to_rgb=True): + """Inplace normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. + """ + # cv2 inplace normalization does not accept uint8 + assert img.dtype != np.uint8 + mean = np.float64(mean.reshape(1, -1)) + stdinv = 1 / np.float64(std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + return img + + +def imdenormalize(img, mean, std, to_bgr=True): + assert img.dtype != np.uint8 + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = cv2.multiply(img, std) # make a copy + cv2.add(img, mean, img) # inplace + if to_bgr: + cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace + return img + + +def iminvert(img): + """Invert (negate) an image. + + Args: + img (ndarray): Image to be inverted. + + Returns: + ndarray: The inverted image. + """ + return np.full_like(img, 255) - img + + +def solarize(img, thr=128): + """Solarize an image (invert all pixel values above a threshold) + + Args: + img (ndarray): Image to be solarized. + thr (int): Threshold for solarizing (0 - 255). + + Returns: + ndarray: The solarized image. + """ + img = np.where(img < thr, img, 255 - img) + return img + + +def posterize(img, bits): + """Posterize an image (reduce the number of bits for each color channel) + + Args: + img (ndarray): Image to be posterized. + bits (int): Number of bits (1 to 8) to use for posterizing. + + Returns: + ndarray: The posterized image. + """ + shift = 8 - bits + img = np.left_shift(np.right_shift(img, shift), shift) + return img + + +def adjust_color(img, alpha=1, beta=None, gamma=0): + r"""It blends the source image and its gray image: + + .. math:: + output = img * alpha + gray\_img * beta + gamma + + Args: + img (ndarray): The input source image. + alpha (int | float): Weight for the source image. Default 1. + beta (int | float): Weight for the converted gray image. + If None, it's assigned the value (1 - `alpha`). + gamma (int | float): Scalar added to each sum. + Same as :func:`cv2.addWeighted`. Default 0. + + Returns: + ndarray: Colored image which has the same size and dtype as input. + """ + gray_img = bgr2gray(img) + gray_img = np.tile(gray_img[..., None], [1, 1, 3]) + if beta is None: + beta = 1 - alpha + colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma) + if not colored_img.dtype == np.uint8: + # Note when the dtype of `img` is not the default `np.uint8` + # (e.g. np.float32), the value in `colored_img` got from cv2 + # is not guaranteed to be in range [0, 255], so here clip + # is needed. + colored_img = np.clip(colored_img, 0, 255) + return colored_img + + +def imequalize(img): + """Equalize the image histogram. + + This function applies a non-linear mapping to the input image, + in order to create a uniform distribution of grayscale values + in the output image. + + Args: + img (ndarray): Image to be equalized. + + Returns: + ndarray: The equalized image. + """ + + def _scale_channel(im, c): + """Scale the data in the corresponding channel.""" + im = im[:, :, c] + # Compute the histogram of the image channel. + histo = np.histogram(im, 256, (0, 255))[0] + # For computing the step, filter out the nonzeros. + nonzero_histo = histo[histo > 0] + step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255 + if not step: + lut = np.array(range(256)) + else: + # Compute the cumulative sum, shifted by step // 2 + # and then normalized by step. + lut = (np.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = np.concatenate([[0], lut[:-1]], 0) + # handle potential integer overflow + lut[lut > 255] = 255 + # If step is zero, return the original image. + # Otherwise, index from lut. + return np.where(np.equal(step, 0), im, lut[im]) + + # Scales each channel independently and then stacks + # the result. + s1 = _scale_channel(img, 0) + s2 = _scale_channel(img, 1) + s3 = _scale_channel(img, 2) + equalized_img = np.stack([s1, s2, s3], axis=-1) + return equalized_img.astype(img.dtype) + + +def adjust_brightness(img, factor=1.): + """Adjust image brightness. + + This function controls the brightness of an image. An + enhancement factor of 0.0 gives a black image. + A factor of 1.0 gives the original image. This function + blends the source image and the degenerated black image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be brightened. + factor (float): A value controls the enhancement. + Factor 1.0 returns the original image, lower + factors mean less color (brightness, contrast, + etc), and higher values more. Default 1. + + Returns: + ndarray: The brightened image. + """ + degenerated = np.zeros_like(img) + # Note manually convert the dtype to np.float32, to + # achieve as close results as PIL.ImageEnhance.Brightness. + # Set beta=1-factor, and gamma=0 + brightened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + brightened_img = np.clip(brightened_img, 0, 255) + return brightened_img.astype(img.dtype) + + +def adjust_contrast(img, factor=1.): + """Adjust image contrast. + + This function controls the contrast of an image. An + enhancement factor of 0.0 gives a solid grey + image. A factor of 1.0 gives the original image. It + blends the source image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be contrasted. BGR order. + factor (float): Same as :func:`mmcv.adjust_brightness`. + + Returns: + ndarray: The contrasted image. + """ + gray_img = bgr2gray(img) + hist = np.histogram(gray_img, 256, (0, 255))[0] + mean = round(np.sum(gray_img) / np.sum(hist)) + degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype) + degenerated = gray2bgr(degenerated) + contrasted_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + contrasted_img = np.clip(contrasted_img, 0, 255) + return contrasted_img.astype(img.dtype) + + +def auto_contrast(img, cutoff=0): + """Auto adjust image contrast. + + This function maximize (normalize) image contrast by first removing cutoff + percent of the lightest and darkest pixels from the histogram and remapping + the image so that the darkest pixel becomes black (0), and the lightest + becomes white (255). + + Args: + img (ndarray): Image to be contrasted. BGR order. + cutoff (int | float | tuple): The cutoff percent of the lightest and + darkest pixels to be removed. If given as tuple, it shall be + (low, high). Otherwise, the single value will be used for both. + Defaults to 0. + + Returns: + ndarray: The contrasted image. + """ + + def _auto_contrast_channel(im, c, cutoff): + im = im[:, :, c] + # Compute the histogram of the image channel. + histo = np.histogram(im, 256, (0, 255))[0] + # Remove cut-off percent pixels from histo + histo_sum = np.cumsum(histo) + cut_low = histo_sum[-1] * cutoff[0] // 100 + cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100 + histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low + histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0) + + # Compute mapping + low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1] + # If all the values have been cut off, return the origin img + if low >= high: + return im + scale = 255.0 / (high - low) + offset = -low * scale + lut = np.array(range(256)) + lut = lut * scale + offset + lut = np.clip(lut, 0, 255) + return lut[im] + + if isinstance(cutoff, (int, float)): + cutoff = (cutoff, cutoff) + else: + assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \ + f'float or tuple, but got {type(cutoff)} instead.' + # Auto adjusts contrast for each channel independently and then stacks + # the result. + s1 = _auto_contrast_channel(img, 0, cutoff) + s2 = _auto_contrast_channel(img, 1, cutoff) + s3 = _auto_contrast_channel(img, 2, cutoff) + contrasted_img = np.stack([s1, s2, s3], axis=-1) + return contrasted_img.astype(img.dtype) + + +def adjust_sharpness(img, factor=1., kernel=None): + """Adjust image sharpness. + + This function controls the sharpness of an image. An + enhancement factor of 0.0 gives a blurred image. A + factor of 1.0 gives the original image. And a factor + of 2.0 gives a sharpened image. It blends the source + image and the degenerated mean image: + + .. math:: + output = img * factor + degenerated * (1 - factor) + + Args: + img (ndarray): Image to be sharpened. BGR order. + factor (float): Same as :func:`mmcv.adjust_brightness`. + kernel (np.ndarray, optional): Filter kernel to be applied on the img + to obtain the degenerated img. Defaults to None. + + Note: + No value sanity check is enforced on the kernel set by users. So with + an inappropriate kernel, the ``adjust_sharpness`` may fail to perform + the function its name indicates but end up performing whatever + transform determined by the kernel. + + Returns: + ndarray: The sharpened image. + """ + + if kernel is None: + # adopted from PIL.ImageFilter.SMOOTH + kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13 + assert isinstance(kernel, np.ndarray), \ + f'kernel must be of type np.ndarray, but got {type(kernel)} instead.' + assert kernel.ndim == 2, \ + f'kernel must have a dimension of 2, but got {kernel.ndim} instead.' + + degenerated = cv2.filter2D(img, -1, kernel) + sharpened_img = cv2.addWeighted( + img.astype(np.float32), factor, degenerated.astype(np.float32), + 1 - factor, 0) + sharpened_img = np.clip(sharpened_img, 0, 255) + return sharpened_img.astype(img.dtype) + + +def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True): + """AlexNet-style PCA jitter. + + This data augmentation is proposed in `ImageNet Classification with Deep + Convolutional Neural Networks + `_. + + Args: + img (ndarray): Image to be adjusted lighting. BGR order. + eigval (ndarray): the eigenvalue of the convariance matrix of pixel + values, respectively. + eigvec (ndarray): the eigenvector of the convariance matrix of pixel + values, respectively. + alphastd (float): The standard deviation for distribution of alpha. + Defaults to 0.1 + to_rgb (bool): Whether to convert img to rgb. + + Returns: + ndarray: The adjusted image. + """ + assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \ + f'eigval and eigvec should both be of type np.ndarray, got ' \ + f'{type(eigval)} and {type(eigvec)} instead.' + + assert eigval.ndim == 1 and eigvec.ndim == 2 + assert eigvec.shape == (3, eigval.shape[0]) + n_eigval = eigval.shape[0] + assert isinstance(alphastd, float), 'alphastd should be of type float, ' \ + f'got {type(alphastd)} instead.' + + img = img.copy().astype(np.float32) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + + alpha = np.random.normal(0, alphastd, n_eigval) + alter = eigvec \ + * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \ + * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval)) + alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape) + img_adjusted = img + alter + return img_adjusted + + +def lut_transform(img, lut_table): + """Transform array by look-up table. + + The function lut_transform fills the output array with values from the + look-up table. Indices of the entries are taken from the input array. + + Args: + img (ndarray): Image to be transformed. + lut_table (ndarray): look-up table of 256 elements; in case of + multi-channel input array, the table should either have a single + channel (in this case the same table is used for all channels) or + the same number of channels as in the input array. + + Returns: + ndarray: The transformed image. + """ + assert isinstance(img, np.ndarray) + assert 0 <= np.min(img) and np.max(img) <= 255 + assert isinstance(lut_table, np.ndarray) + assert lut_table.shape == (256, ) + + return cv2.LUT(np.array(img, dtype=np.uint8), lut_table) + + +def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + img (ndarray): Image to be processed. + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + + Returns: + ndarray: The processed image. + """ + assert isinstance(img, np.ndarray) + assert img.ndim == 2 + assert isinstance(clip_limit, (float, int)) + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + + clahe = cv2.createCLAHE(clip_limit, tile_grid_size) + return clahe.apply(np.array(img, dtype=np.uint8)) diff --git a/annotator/uniformer/mmcv/model_zoo/deprecated.json b/annotator/uniformer/mmcv/model_zoo/deprecated.json new file mode 100644 index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/deprecated.json @@ -0,0 +1,6 @@ +{ + "resnet50_caffe": "detectron/resnet50_caffe", + "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", + "resnet101_caffe": "detectron/resnet101_caffe", + "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" +} diff --git a/annotator/uniformer/mmcv/model_zoo/mmcls.json b/annotator/uniformer/mmcv/model_zoo/mmcls.json new file mode 100644 index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/mmcls.json @@ -0,0 +1,31 @@ +{ + "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth", + "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth", + "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth", + "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth", + "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth", + "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth", + "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth", + "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth", + "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth", + "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth", + "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth", + "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth", + "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth", + "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth", + "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth", + "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth", + "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth", + "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth", + "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth", + "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth", + "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth", + "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth", + "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth", + "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth", + "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth", + "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth", + "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth", + "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth", + "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth" +} diff --git a/annotator/uniformer/mmcv/model_zoo/open_mmlab.json b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json new file mode 100644 index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0 --- /dev/null +++ b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json @@ -0,0 +1,50 @@ +{ + "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth", + "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth", + "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth", + "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth", + "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth", + "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth", + "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", + "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth", + "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", + "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", + "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth", + "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth", + "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", + "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", + "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", + "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", + "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", + "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", + "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth", + "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth", + "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", + "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", + "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth", + "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", + "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", + "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", + "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth", + "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth", + "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth", + "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth", + "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth", + "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth", + "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth", + "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth", + "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth", + "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth", + "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth", + "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth", + "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth", + "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth", + "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth", + "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth", + "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth", + "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth", + "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth", + "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth", + "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth", + "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth" +} diff --git a/annotator/uniformer/mmcv/ops/__init__.py b/annotator/uniformer/mmcv/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd --- /dev/null +++ b/annotator/uniformer/mmcv/ops/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assign_score_withk import assign_score_withk +from .ball_query import ball_query +from .bbox import bbox_overlaps +from .border_align import BorderAlign, border_align +from .box_iou_rotated import box_iou_rotated +from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive +from .cc_attention import CrissCrossAttention +from .contour_expand import contour_expand +from .corner_pool import CornerPool +from .correlation import Correlation +from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d +from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack, + ModulatedDeformRoIPoolPack, deform_roi_pool) +from .deprecated_wrappers import Conv2d_deprecated as Conv2d +from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d +from .deprecated_wrappers import Linear_deprecated as Linear +from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d +from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, + sigmoid_focal_loss, softmax_focal_loss) +from .furthest_point_sample import (furthest_point_sample, + furthest_point_sample_with_dist) +from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu +from .gather_points import gather_points +from .group_points import GroupAll, QueryAndGroup, grouping_operation +from .info import (get_compiler_version, get_compiling_cuda_version, + get_onnxruntime_op_path) +from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev +from .knn import knn +from .masked_conv import MaskedConv2d, masked_conv2d +from .modulated_deform_conv import (ModulatedDeformConv2d, + ModulatedDeformConv2dPack, + modulated_deform_conv2d) +from .multi_scale_deform_attn import MultiScaleDeformableAttention +from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms +from .pixel_group import pixel_group +from .point_sample import (SimpleRoIAlign, point_sample, + rel_roi_point_to_rel_img_point) +from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu, + points_in_boxes_part) +from .points_sampler import PointsSampler +from .psa_mask import PSAMask +from .roi_align import RoIAlign, roi_align +from .roi_align_rotated import RoIAlignRotated, roi_align_rotated +from .roi_pool import RoIPool, roi_pool +from .roiaware_pool3d import RoIAwarePool3d +from .roipoint_pool3d import RoIPointPool3d +from .saconv import SAConv2d +from .scatter_points import DynamicScatter, dynamic_scatter +from .sync_bn import SyncBatchNorm +from .three_interpolate import three_interpolate +from .three_nn import three_nn +from .tin_shift import TINShift, tin_shift +from .upfirdn2d import upfirdn2d +from .voxelize import Voxelization, voxelization + +__all__ = [ + 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', + 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', + 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', + 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', + 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', + 'get_compiler_version', 'get_compiling_cuda_version', + 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', + 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', + 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', + 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', + 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', + 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', + 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', + 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', + 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn', + 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', + 'border_align', 'gather_points', 'furthest_point_sample', + 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', + 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization', + 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', + 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all' +] diff --git a/annotator/uniformer/mmcv/ops/assign_score_withk.py b/annotator/uniformer/mmcv/ops/assign_score_withk.py new file mode 100644 index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/assign_score_withk.py @@ -0,0 +1,123 @@ +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward']) + + +class AssignScoreWithK(Function): + r"""Perform weighted sum to generate output features according to scores. + Modified from `PAConv `_. + + This is a memory-efficient CUDA implementation of assign_scores operation, + which first transform all point features with weight bank, then assemble + neighbor features with ``knn_idx`` and perform weighted sum of ``scores``. + + See the `paper `_ appendix Sec. D for + more detailed descriptions. + + Note: + This implementation assumes using ``neighbor`` kernel input, which is + (point_features - center_features, point_features). + See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/ + pointnet2/paconv.py#L128 for more details. + """ + + @staticmethod + def forward(ctx, + scores, + point_features, + center_features, + knn_idx, + aggregate='sum'): + """ + Args: + scores (torch.Tensor): (B, npoint, K, M), predicted scores to + aggregate weight matrices in the weight bank. + ``npoint`` is the number of sampled centers. + ``K`` is the number of queried neighbors. + ``M`` is the number of weight matrices in the weight bank. + point_features (torch.Tensor): (B, N, M, out_dim) + Pre-computed point features to be aggregated. + center_features (torch.Tensor): (B, N, M, out_dim) + Pre-computed center features to be aggregated. + knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN. + We assume the first idx in each row is the idx of the center. + aggregate (str, optional): Aggregation method. + Can be 'sum', 'avg' or 'max'. Defaults: 'sum'. + + Returns: + torch.Tensor: (B, out_dim, npoint, K), the aggregated features. + """ + agg = {'sum': 0, 'avg': 1, 'max': 2} + + B, N, M, out_dim = point_features.size() + _, npoint, K, _ = scores.size() + + output = point_features.new_zeros((B, out_dim, npoint, K)) + ext_module.assign_score_withk_forward( + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + output, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg[aggregate]) + + ctx.save_for_backward(output, point_features, center_features, scores, + knn_idx) + ctx.agg = agg[aggregate] + + return output + + @staticmethod + def backward(ctx, grad_out): + """ + Args: + grad_out (torch.Tensor): (B, out_dim, npoint, K) + + Returns: + grad_scores (torch.Tensor): (B, npoint, K, M) + grad_point_features (torch.Tensor): (B, N, M, out_dim) + grad_center_features (torch.Tensor): (B, N, M, out_dim) + """ + _, point_features, center_features, scores, knn_idx = ctx.saved_tensors + + agg = ctx.agg + + B, N, M, out_dim = point_features.size() + _, npoint, K, _ = scores.size() + + grad_point_features = point_features.new_zeros(point_features.shape) + grad_center_features = center_features.new_zeros(center_features.shape) + grad_scores = scores.new_zeros(scores.shape) + + ext_module.assign_score_withk_backward( + grad_out.contiguous(), + point_features.contiguous(), + center_features.contiguous(), + scores.contiguous(), + knn_idx.contiguous(), + grad_point_features, + grad_center_features, + grad_scores, + B=B, + N0=N, + N1=npoint, + M=M, + K=K, + O=out_dim, + aggregate=agg) + + return grad_scores, grad_point_features, \ + grad_center_features, None, None + + +assign_score_withk = AssignScoreWithK.apply diff --git a/annotator/uniformer/mmcv/ops/ball_query.py b/annotator/uniformer/mmcv/ops/ball_query.py new file mode 100644 index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a --- /dev/null +++ b/annotator/uniformer/mmcv/ops/ball_query.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['ball_query_forward']) + + +class BallQuery(Function): + """Find nearby points in spherical space.""" + + @staticmethod + def forward(ctx, min_radius: float, max_radius: float, sample_num: int, + xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor: + """ + Args: + min_radius (float): minimum radius of the balls. + max_radius (float): maximum radius of the balls. + sample_num (int): maximum number of features in the balls. + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) centers of the ball query. + + Returns: + Tensor: (B, npoint, nsample) tensor with the indices of + the features that form the query balls. + """ + assert center_xyz.is_contiguous() + assert xyz.is_contiguous() + assert min_radius < max_radius + + B, N, _ = xyz.size() + npoint = center_xyz.size(1) + idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int) + + ext_module.ball_query_forward( + center_xyz, + xyz, + idx, + b=B, + n=N, + m=npoint, + min_radius=min_radius, + max_radius=max_radius, + nsample=sample_num) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None + + +ball_query = BallQuery.apply diff --git a/annotator/uniformer/mmcv/ops/bbox.py b/annotator/uniformer/mmcv/ops/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/bbox.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps']) + + +def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0): + """Calculate overlap between two set of bboxes. + + If ``aligned`` is ``False``, then calculate the ious between each bbox + of bboxes1 and bboxes2, otherwise the ious between each aligned pair of + bboxes1 and bboxes2. + + Args: + bboxes1 (Tensor): shape (m, 4) in format or empty. + bboxes2 (Tensor): shape (n, 4) in format or empty. + If aligned is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union) or iof (intersection over + foreground). + + Returns: + ious(Tensor): shape (m, n) if aligned == False else shape (m, 1) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> bbox_overlaps(bboxes1, bboxes2) + tensor([[0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 1.0000], + [0.0000, 0.0000, 0.0000]]) + + Example: + >>> empty = torch.FloatTensor([]) + >>> nonempty = torch.FloatTensor([ + >>> [0, 0, 10, 9], + >>> ]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + + mode_dict = {'iou': 0, 'iof': 1} + assert mode in mode_dict.keys() + mode_flag = mode_dict[mode] + # Either the boxes are empty or the length of boxes' last dimension is 4 + assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) + assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) + assert offset == 1 or offset == 0 + + rows = bboxes1.size(0) + cols = bboxes2.size(0) + if aligned: + assert rows == cols + + if rows * cols == 0: + return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols) + + if aligned: + ious = bboxes1.new_zeros(rows) + else: + ious = bboxes1.new_zeros((rows, cols)) + ext_module.bbox_overlaps( + bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset) + return ious diff --git a/annotator/uniformer/mmcv/ops/border_align.py b/annotator/uniformer/mmcv/ops/border_align.py new file mode 100644 index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/border_align.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modified from +# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['border_align_forward', 'border_align_backward']) + + +class BorderAlignFunction(Function): + + @staticmethod + def symbolic(g, input, boxes, pool_size): + return g.op( + 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size) + + @staticmethod + def forward(ctx, input, boxes, pool_size): + ctx.pool_size = pool_size + ctx.input_shape = input.size() + + assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]' + assert boxes.size(2) == 4, \ + 'the last dimension of boxes must be (x1, y1, x2, y2)' + assert input.size(1) % 4 == 0, \ + 'the channel for input feature must be divisible by factor 4' + + # [B, C//4, H*W, 4] + output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4) + output = input.new_zeros(output_shape) + # `argmax_idx` only used for backward + argmax_idx = input.new_zeros(output_shape).to(torch.int) + + ext_module.border_align_forward( + input, boxes, output, argmax_idx, pool_size=ctx.pool_size) + + ctx.save_for_backward(boxes, argmax_idx) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + boxes, argmax_idx = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + # complex head architecture may cause grad_output uncontiguous + grad_output = grad_output.contiguous() + ext_module.border_align_backward( + grad_output, + boxes, + argmax_idx, + grad_input, + pool_size=ctx.pool_size) + return grad_input, None, None + + +border_align = BorderAlignFunction.apply + + +class BorderAlign(nn.Module): + r"""Border align pooling layer. + + Applies border_align over the input feature based on predicted bboxes. + The details were described in the paper + `BorderDet: Border Feature for Dense Object Detection + `_. + + For each border line (e.g. top, left, bottom or right) of each box, + border_align does the following: + 1. uniformly samples `pool_size`+1 positions on this line, involving \ + the start and end points. + 2. the corresponding features on these points are computed by \ + bilinear interpolation. + 3. max pooling over all the `pool_size`+1 positions are used for \ + computing pooled feature. + + Args: + pool_size (int): number of positions sampled over the boxes' borders + (e.g. top, bottom, left, right). + + """ + + def __init__(self, pool_size): + super(BorderAlign, self).__init__() + self.pool_size = pool_size + + def forward(self, input, boxes): + """ + Args: + input: Features with shape [N,4C,H,W]. Channels ranged in [0,C), + [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom, + right features respectively. + boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2). + + Returns: + Tensor: Pooled features with shape [N,C,H*W,4]. The order is + (top,left,bottom,right) for the last dimension. + """ + return border_align(input, boxes, self.pool_size) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(pool_size={self.pool_size})' + return s diff --git a/annotator/uniformer/mmcv/ops/box_iou_rotated.py b/annotator/uniformer/mmcv/ops/box_iou_rotated.py new file mode 100644 index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/box_iou_rotated.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated']) + + +def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False): + """Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in + (x_center, y_center, width, height, angle) format. + + If ``aligned`` is ``False``, then calculate the ious between each bbox + of bboxes1 and bboxes2, otherwise the ious between each aligned pair of + bboxes1 and bboxes2. + + Arguments: + boxes1 (Tensor): rotated bboxes 1. \ + It has shape (N, 5), indicating (x, y, w, h, theta) for each row. + Note that theta is in radian. + boxes2 (Tensor): rotated bboxes 2. \ + It has shape (M, 5), indicating (x, y, w, h, theta) for each row. + Note that theta is in radian. + mode (str): "iou" (intersection over union) or iof (intersection over + foreground). + + Returns: + ious(Tensor): shape (N, M) if aligned == False else shape (N,) + """ + assert mode in ['iou', 'iof'] + mode_dict = {'iou': 0, 'iof': 1} + mode_flag = mode_dict[mode] + rows = bboxes1.size(0) + cols = bboxes2.size(0) + if aligned: + ious = bboxes1.new_zeros(rows) + else: + ious = bboxes1.new_zeros((rows * cols)) + bboxes1 = bboxes1.contiguous() + bboxes2 = bboxes2.contiguous() + ext_module.box_iou_rotated( + bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned) + if not aligned: + ious = ious.view(rows, cols) + return ious diff --git a/annotator/uniformer/mmcv/ops/carafe.py b/annotator/uniformer/mmcv/ops/carafe.py new file mode 100644 index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/carafe.py @@ -0,0 +1,287 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.nn.modules.module import Module + +from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward', + 'carafe_backward' +]) + + +class CARAFENaiveFunction(Function): + + @staticmethod + def symbolic(g, features, masks, kernel_size, group_size, scale_factor): + return g.op( + 'mmcv::MMCVCARAFENaive', + features, + masks, + kernel_size_i=kernel_size, + group_size_i=group_size, + scale_factor_f=scale_factor) + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + ext_module.carafe_naive_forward( + features, + masks, + output, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + grad_input = torch.zeros_like(features) + grad_masks = torch.zeros_like(masks) + ext_module.carafe_naive_backward( + grad_output.contiguous(), + features, + masks, + grad_input, + grad_masks, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + return grad_input, grad_masks, None, None, None + + +carafe_naive = CARAFENaiveFunction.apply + + +class CARAFENaive(Module): + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFENaive, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return carafe_naive(features, masks, self.kernel_size, self.group_size, + self.scale_factor) + + +class CARAFEFunction(Function): + + @staticmethod + def symbolic(g, features, masks, kernel_size, group_size, scale_factor): + return g.op( + 'mmcv::MMCVCARAFE', + features, + masks, + kernel_size_i=kernel_size, + group_size_i=group_size, + scale_factor_f=scale_factor) + + @staticmethod + def forward(ctx, features, masks, kernel_size, group_size, scale_factor): + assert scale_factor >= 1 + assert masks.size(1) == kernel_size * kernel_size * group_size + assert masks.size(-1) == features.size(-1) * scale_factor + assert masks.size(-2) == features.size(-2) * scale_factor + assert features.size(1) % group_size == 0 + assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1 + ctx.kernel_size = kernel_size + ctx.group_size = group_size + ctx.scale_factor = scale_factor + ctx.feature_size = features.size() + ctx.mask_size = masks.size() + + n, c, h, w = features.size() + output = features.new_zeros((n, c, h * scale_factor, w * scale_factor)) + routput = features.new_zeros(output.size(), requires_grad=False) + rfeatures = features.new_zeros(features.size(), requires_grad=False) + rmasks = masks.new_zeros(masks.size(), requires_grad=False) + ext_module.carafe_forward( + features, + masks, + rfeatures, + routput, + rmasks, + output, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + + if features.requires_grad or masks.requires_grad: + ctx.save_for_backward(features, masks, rfeatures) + return output + + @staticmethod + def backward(ctx, grad_output): + assert grad_output.is_cuda + + features, masks, rfeatures = ctx.saved_tensors + kernel_size = ctx.kernel_size + group_size = ctx.group_size + scale_factor = ctx.scale_factor + + rgrad_output = torch.zeros_like(grad_output, requires_grad=False) + rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False) + rgrad_input = torch.zeros_like(features, requires_grad=False) + rgrad_masks = torch.zeros_like(masks, requires_grad=False) + grad_input = torch.zeros_like(features, requires_grad=False) + grad_masks = torch.zeros_like(masks, requires_grad=False) + ext_module.carafe_backward( + grad_output.contiguous(), + rfeatures, + masks, + rgrad_output, + rgrad_input_hs, + rgrad_input, + rgrad_masks, + grad_input, + grad_masks, + kernel_size=kernel_size, + group_size=group_size, + scale_factor=scale_factor) + return grad_input, grad_masks, None, None, None + + +carafe = CARAFEFunction.apply + + +class CARAFE(Module): + """ CARAFE: Content-Aware ReAssembly of FEatures + + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + kernel_size (int): reassemble kernel size + group_size (int): reassemble group size + scale_factor (int): upsample ratio + + Returns: + upsampled feature map + """ + + def __init__(self, kernel_size, group_size, scale_factor): + super(CARAFE, self).__init__() + + assert isinstance(kernel_size, int) and isinstance( + group_size, int) and isinstance(scale_factor, int) + self.kernel_size = kernel_size + self.group_size = group_size + self.scale_factor = scale_factor + + def forward(self, features, masks): + return carafe(features, masks, self.kernel_size, self.group_size, + self.scale_factor) + + +@UPSAMPLE_LAYERS.register_module(name='carafe') +class CARAFEPack(nn.Module): + """A unified package of CARAFE upsampler that contains: 1) channel + compressor 2) content encoder 3) CARAFE op. + + Official implementation of ICCV 2019 paper + CARAFE: Content-Aware ReAssembly of FEatures + Please refer to https://arxiv.org/abs/1905.02188 for more details. + + Args: + channels (int): input feature channels + scale_factor (int): upsample ratio + up_kernel (int): kernel size of CARAFE op + up_group (int): group size of CARAFE op + encoder_kernel (int): kernel size of content encoder + encoder_dilation (int): dilation of content encoder + compressed_channels (int): output channels of channels compressor + + Returns: + upsampled feature map + """ + + def __init__(self, + channels, + scale_factor, + up_kernel=5, + up_group=1, + encoder_kernel=3, + encoder_dilation=1, + compressed_channels=64): + super(CARAFEPack, self).__init__() + self.channels = channels + self.scale_factor = scale_factor + self.up_kernel = up_kernel + self.up_group = up_group + self.encoder_kernel = encoder_kernel + self.encoder_dilation = encoder_dilation + self.compressed_channels = compressed_channels + self.channel_compressor = nn.Conv2d(channels, self.compressed_channels, + 1) + self.content_encoder = nn.Conv2d( + self.compressed_channels, + self.up_kernel * self.up_kernel * self.up_group * + self.scale_factor * self.scale_factor, + self.encoder_kernel, + padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2), + dilation=self.encoder_dilation, + groups=1) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + normal_init(self.content_encoder, std=0.001) + + def kernel_normalizer(self, mask): + mask = F.pixel_shuffle(mask, self.scale_factor) + n, mask_c, h, w = mask.size() + # use float division explicitly, + # to void inconsistency while exporting to onnx + mask_channel = int(mask_c / float(self.up_kernel**2)) + mask = mask.view(n, mask_channel, -1, h, w) + + mask = F.softmax(mask, dim=2, dtype=mask.dtype) + mask = mask.view(n, mask_c, h, w).contiguous() + + return mask + + def feature_reassemble(self, x, mask): + x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor) + return x + + def forward(self, x): + compressed_x = self.channel_compressor(x) + mask = self.content_encoder(compressed_x) + mask = self.kernel_normalizer(mask) + + x = self.feature_reassemble(x, mask) + return x diff --git a/annotator/uniformer/mmcv/ops/cc_attention.py b/annotator/uniformer/mmcv/ops/cc_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9207aa95e6730bd9b3362dee612059a5f0ce1c5e --- /dev/null +++ b/annotator/uniformer/mmcv/ops/cc_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.cnn import PLUGIN_LAYERS, Scale + + +def NEG_INF_DIAG(n, device): + """Returns a diagonal matrix of size [n, n]. + + The diagonal are all "-inf". This is for avoiding calculating the + overlapped element in the Criss-Cross twice. + """ + return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0) + + +@PLUGIN_LAYERS.register_module() +class CrissCrossAttention(nn.Module): + """Criss-Cross Attention Module. + + .. note:: + Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch + to a pure PyTorch and equivalent implementation. For more + details, please refer to https://github.com/open-mmlab/mmcv/pull/1201. + + Speed comparison for one forward pass + + - Input size: [2,512,97,97] + - Device: 1 NVIDIA GeForce RTX 2080 Ti + + +-----------------------+---------------+------------+---------------+ + | |PyTorch version|CUDA version|Relative speed | + +=======================+===============+============+===============+ + |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x | + +-----------------------+---------------+------------+---------------+ + |no with torch.no_grad()|0.00562803 s |0.0301349 s |5.4x | + +-----------------------+---------------+------------+---------------+ + + Args: + in_channels (int): Channels of the input feature map. + """ + + def __init__(self, in_channels): + super().__init__() + self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) + self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) + self.value_conv = nn.Conv2d(in_channels, in_channels, 1) + self.gamma = Scale(0.) + self.in_channels = in_channels + + def forward(self, x): + """forward function of Criss-Cross Attention. + + Args: + x (Tensor): Input feature. \ + shape (batch_size, in_channels, height, width) + Returns: + Tensor: Output of the layer, with shape of \ + (batch_size, in_channels, height, width) + """ + B, C, H, W = x.size() + query = self.query_conv(x) + key = self.key_conv(x) + value = self.value_conv(x) + energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG( + H, query.device) + energy_H = energy_H.transpose(1, 2) + energy_W = torch.einsum('bchw,bchj->bhwj', query, key) + attn = F.softmax( + torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)] + out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H]) + out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:]) + + out = self.gamma(out) + x + out = out.contiguous() + + return out + + def __repr__(self): + s = self.__class__.__name__ + s += f'(in_channels={self.in_channels})' + return s diff --git a/annotator/uniformer/mmcv/ops/contour_expand.py b/annotator/uniformer/mmcv/ops/contour_expand.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/contour_expand.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['contour_expand']) + + +def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area, + kernel_num): + """Expand kernel contours so that foreground pixels are assigned into + instances. + + Arguments: + kernel_mask (np.array or Tensor): The instance kernel mask with + size hxw. + internal_kernel_label (np.array or Tensor): The instance internal + kernel label with size hxw. + min_kernel_area (int): The minimum kernel area. + kernel_num (int): The instance kernel number. + + Returns: + label (list): The instance index map with size hxw. + """ + assert isinstance(kernel_mask, (torch.Tensor, np.ndarray)) + assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray)) + assert isinstance(min_kernel_area, int) + assert isinstance(kernel_num, int) + + if isinstance(kernel_mask, np.ndarray): + kernel_mask = torch.from_numpy(kernel_mask) + if isinstance(internal_kernel_label, np.ndarray): + internal_kernel_label = torch.from_numpy(internal_kernel_label) + + if torch.__version__ == 'parrots': + if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0: + label = [] + else: + label = ext_module.contour_expand( + kernel_mask, + internal_kernel_label, + min_kernel_area=min_kernel_area, + kernel_num=kernel_num) + label = label.tolist() + else: + label = ext_module.contour_expand(kernel_mask, internal_kernel_label, + min_kernel_area, kernel_num) + return label diff --git a/annotator/uniformer/mmcv/ops/corner_pool.py b/annotator/uniformer/mmcv/ops/corner_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/corner_pool.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward', + 'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward', + 'right_pool_forward', 'right_pool_backward' +]) + +_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} + + +class TopPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.top_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.top_pool_backward(input, grad_output) + return output + + +class BottomPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.bottom_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.bottom_pool_backward(input, grad_output) + return output + + +class LeftPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.left_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.left_pool_backward(input, grad_output) + return output + + +class RightPoolFunction(Function): + + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) + return output + + @staticmethod + def forward(ctx, input): + output = ext_module.right_pool_forward(input) + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + output = ext_module.right_pool_backward(input, grad_output) + return output + + +class CornerPool(nn.Module): + """Corner Pooling. + + Corner Pooling is a new type of pooling layer that helps a + convolutional network better localize corners of bounding boxes. + + Please refer to https://arxiv.org/abs/1808.01244 for more details. + Code is modified from https://github.com/princeton-vl/CornerNet-Lite. + + Args: + mode(str): Pooling orientation for the pooling layer + + - 'bottom': Bottom Pooling + - 'left': Left Pooling + - 'right': Right Pooling + - 'top': Top Pooling + + Returns: + Feature map after pooling. + """ + + pool_functions = { + 'bottom': BottomPoolFunction, + 'left': LeftPoolFunction, + 'right': RightPoolFunction, + 'top': TopPoolFunction, + } + + cummax_dim_flip = { + 'bottom': (2, False), + 'left': (3, True), + 'right': (3, False), + 'top': (2, True), + } + + def __init__(self, mode): + super(CornerPool, self).__init__() + assert mode in self.pool_functions + self.mode = mode + self.corner_pool = self.pool_functions[mode] + + def forward(self, x): + if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': + if torch.onnx.is_in_onnx_export(): + assert torch.__version__ >= '1.7.0', \ + 'When `cummax` serves as an intermediate component whose '\ + 'outputs is used as inputs for another modules, it\'s '\ + 'expected that pytorch version must be >= 1.7.0, '\ + 'otherwise Error appears like: `RuntimeError: tuple '\ + 'appears in op that does not forward tuples, unsupported '\ + 'kind: prim::PythonOp`.' + + dim, flip = self.cummax_dim_flip[self.mode] + if flip: + x = x.flip(dim) + pool_tensor, _ = torch.cummax(x, dim=dim) + if flip: + pool_tensor = pool_tensor.flip(dim) + return pool_tensor + else: + return self.corner_pool.apply(x) diff --git a/annotator/uniformer/mmcv/ops/correlation.py b/annotator/uniformer/mmcv/ops/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/correlation.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['correlation_forward', 'correlation_backward']) + + +class CorrelationFunction(Function): + + @staticmethod + def forward(ctx, + input1, + input2, + kernel_size=1, + max_displacement=1, + stride=1, + padding=1, + dilation=1, + dilation_patch=1): + + ctx.save_for_backward(input1, input2) + + kH, kW = ctx.kernel_size = _pair(kernel_size) + patch_size = max_displacement * 2 + 1 + ctx.patch_size = patch_size + dH, dW = ctx.stride = _pair(stride) + padH, padW = ctx.padding = _pair(padding) + dilationH, dilationW = ctx.dilation = _pair(dilation) + dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair( + dilation_patch) + + output_size = CorrelationFunction._output_size(ctx, input1) + + output = input1.new_zeros(output_size) + + ext_module.correlation_forward( + input1, + input2, + output, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input1, input2 = ctx.saved_tensors + + kH, kW = ctx.kernel_size + patch_size = ctx.patch_size + padH, padW = ctx.padding + dilationH, dilationW = ctx.dilation + dilation_patchH, dilation_patchW = ctx.dilation_patch + dH, dW = ctx.stride + grad_input1 = torch.zeros_like(input1) + grad_input2 = torch.zeros_like(input2) + + ext_module.correlation_backward( + grad_output, + input1, + input2, + grad_input1, + grad_input2, + kH=kH, + kW=kW, + patchH=patch_size, + patchW=patch_size, + padH=padH, + padW=padW, + dilationH=dilationH, + dilationW=dilationW, + dilation_patchH=dilation_patchH, + dilation_patchW=dilation_patchW, + dH=dH, + dW=dW) + return grad_input1, grad_input2, None, None, None, None, None, None + + @staticmethod + def _output_size(ctx, input1): + iH, iW = input1.size(2), input1.size(3) + batch_size = input1.size(0) + kH, kW = ctx.kernel_size + patch_size = ctx.patch_size + dH, dW = ctx.stride + padH, padW = ctx.padding + dilationH, dilationW = ctx.dilation + dilatedKH = (kH - 1) * dilationH + 1 + dilatedKW = (kW - 1) * dilationW + 1 + + oH = int((iH + 2 * padH - dilatedKH) / dH + 1) + oW = int((iW + 2 * padW - dilatedKW) / dW + 1) + + output_size = (batch_size, patch_size, patch_size, oH, oW) + return output_size + + +class Correlation(nn.Module): + r"""Correlation operator + + This correlation operator works for optical flow correlation computation. + + There are two batched tensors with shape :math:`(N, C, H, W)`, + and the correlation output's shape is :math:`(N, max\_displacement \times + 2 + 1, max\_displacement * 2 + 1, H_{out}, W_{out})` + + where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding - + dilation \times (kernel\_size - 1) - 1} + {stride} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation + \times (kernel\_size - 1) - 1} + {stride} + 1\right\rfloor + + the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding + window convolution between input1 and shifted input2, + + .. math:: + Corr(N_i, dx, dy) = + \sum_{c=0}^{C-1} + input1(N_i, c) \star + \mathcal{S}(input2(N_i, c), dy, dx) + + where :math:`\star` is the valid 2d sliding window convolution operator, + and :math:`\mathcal{S}` means shifting the input features (auto-complete + zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in + [-max\_displacement \times dilation\_patch, max\_displacement \times + dilation\_patch]`. + + Args: + kernel_size (int): The size of sliding window i.e. local neighborhood + representing the center points and involved in correlation + computation. Defaults to 1. + max_displacement (int): The radius for computing correlation volume, + but the actual working space can be dilated by dilation_patch. + Defaults to 1. + stride (int): The stride of the sliding blocks in the input spatial + dimensions. Defaults to 1. + padding (int): Zero padding added to all four sides of the input1. + Defaults to 0. + dilation (int): The spacing of local neighborhood that will involved + in correlation. Defaults to 1. + dilation_patch (int): The spacing between position need to compute + correlation. Defaults to 1. + """ + + def __init__(self, + kernel_size: int = 1, + max_displacement: int = 1, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + dilation_patch: int = 1) -> None: + super().__init__() + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride = stride + self.padding = padding + self.dilation = dilation + self.dilation_patch = dilation_patch + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + return CorrelationFunction.apply(input1, input2, self.kernel_size, + self.max_displacement, self.stride, + self.padding, self.dilation, + self.dilation_patch) + + def __repr__(self) -> str: + s = self.__class__.__name__ + s += f'(kernel_size={self.kernel_size}, ' + s += f'max_displacement={self.max_displacement}, ' + s += f'stride={self.stride}, ' + s += f'padding={self.padding}, ' + s += f'dilation={self.dilation}, ' + s += f'dilation_patch={self.dilation_patch})' + return s diff --git a/annotator/uniformer/mmcv/ops/deform_conv.py b/annotator/uniformer/mmcv/ops/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f8c75ee774823eea334e3b3732af6a18f55038 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deform_conv.py @@ -0,0 +1,405 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext('_ext', [ + 'deform_conv_forward', 'deform_conv_backward_input', + 'deform_conv_backward_parameters' +]) + + +class DeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, + input, + offset, + weight, + stride, + padding, + dilation, + groups, + deform_groups, + bias=False, + im2col_step=32): + return g.op( + 'mmcv::MMCVDeformConv2d', + input, + offset, + weight, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups, + bias_i=bias, + im2col_step_i=im2col_step) + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=False, + im2col_step=32): + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + assert bias is False, 'Only support bias is False.' + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.im2col_step = im2col_step + + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. + input = input.type_as(offset) + weight = weight.type_as(input) + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConv2dFunction._output_size(ctx, input, weight)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + ext_module.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + cur_im2col_step = min(ctx.im2col_step, input.size(0)) + assert (input.size(0) % cur_im2col_step + ) == 0, 'batch size must be divisible by im2col_step' + + grad_output = grad_output.contiguous() + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + ext_module.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + im2col_step=cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + ext_module.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + kW=weight.size(3), + kH=weight.size(2), + dW=ctx.stride[1], + dH=ctx.stride[0], + padW=ctx.padding[1], + padH=ctx.padding[0], + dilationW=ctx.dilation[1], + dilationH=ctx.dilation[0], + group=ctx.groups, + deformable_group=ctx.deform_groups, + scale=1, + im2col_step=cur_im2col_step) + + return grad_input, grad_offset, grad_weight, \ + None, None, None, None, None, None, None + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +deform_conv2d = DeformConv2dFunction.apply + + +class DeformConv2d(nn.Module): + r"""Deformable 2D convolution. + + Applies a deformable 2D convolution over an input signal composed of + several input planes. DeformConv2d was described in the paper + `Deformable Convolutional Networks + `_ + + Note: + The argument ``im2col_step`` was added in version 1.3.17, which means + number of samples processed by the ``im2col_cuda_kernel`` per call. + It enables users to define ``batch_size`` and ``im2col_step`` more + flexibly and solved `issue mmcv#1440 + `_. + + Args: + in_channels (int): Number of channels in the input image. + out_channels (int): Number of channels produced by the convolution. + kernel_size(int, tuple): Size of the convolving kernel. + stride(int, tuple): Stride of the convolution. Default: 1. + padding (int or tuple): Zero-padding added to both sides of the input. + Default: 0. + dilation (int or tuple): Spacing between kernel elements. Default: 1. + groups (int): Number of blocked connections from input. + channels to output channels. Default: 1. + deform_groups (int): Number of deformable group partitions. + bias (bool): If True, adds a learnable bias to the output. + Default: False. + im2col_step (int): Number of samples processed by im2col_cuda_kernel + per call. It will work when ``batch_size`` > ``im2col_step``, but + ``batch_size`` must be divisible by ``im2col_step``. Default: 32. + `New in version 1.3.17.` + """ + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='DeformConv2d') + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + deform_groups: int = 1, + bias: bool = False, + im2col_step: int = 32) -> None: + super(DeformConv2d, self).__init__() + + assert not bias, \ + f'bias={bias} is not supported in DeformConv2d.' + assert in_channels % groups == 0, \ + f'in_channels {in_channels} cannot be divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} cannot be divisible by groups \ + {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + self.im2col_step = im2col_step + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + # only weight, no bias + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, + *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + # switch the initialization of `self.weight` to the standard kaiming + # method described in `Delving deep into rectifiers: Surpassing + # human-level performance on ImageNet classification` - He, K. et al. + # (2015), using a uniform distribution + nn.init.kaiming_uniform_(self.weight, nonlinearity='relu') + + def forward(self, x: Tensor, offset: Tensor) -> Tensor: + """Deformable Convolutional forward function. + + Args: + x (Tensor): Input feature, shape (B, C_in, H_in, W_in) + offset (Tensor): Offset for deformable convolution, shape + (B, deform_groups*kernel_size[0]*kernel_size[1]*2, + H_out, W_out), H_out, W_out are equal to the output's. + + An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Returns: + Tensor: Output of the layer. + """ + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) < + self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0) + offset = offset.contiguous() + out = deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - + pad_w].contiguous() + return out + + def __repr__(self): + s = self.__class__.__name__ + s += f'(in_channels={self.in_channels},\n' + s += f'out_channels={self.out_channels},\n' + s += f'kernel_size={self.kernel_size},\n' + s += f'stride={self.stride},\n' + s += f'padding={self.padding},\n' + s += f'dilation={self.dilation},\n' + s += f'groups={self.groups},\n' + s += f'deform_groups={self.deform_groups},\n' + # bias is not supported in DeformConv2d. + s += 'bias=False)' + return s + + +@CONV_LAYERS.register_module('DCN') +class DeformConv2dPack(DeformConv2d): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. + The spatial arrangement is like: + + .. code:: text + + (x0, y0) (x1, y1) (x2, y2) + (x3, y3) (x4, y4) (x5, y5) + (x6, y6) (x7, y7) (x8, y8) + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConv2dPack, self).__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv2d(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deform_groups, + False, self.im2col_step) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, DeformConvPack loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/annotator/uniformer/mmcv/ops/deform_roi_pool.py b/annotator/uniformer/mmcv/ops/deform_roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deform_roi_pool.py @@ -0,0 +1,204 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward']) + + +class DeformRoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, offset, output_size, spatial_scale, + sampling_ratio, gamma): + return g.op( + 'mmcv::MMCVDeformRoIPool', + input, + rois, + offset, + pooled_height_i=output_size[0], + pooled_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_f=sampling_ratio, + gamma_f=gamma) + + @staticmethod + def forward(ctx, + input, + rois, + offset, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + if offset is None: + offset = input.new_zeros(0) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = float(spatial_scale) + ctx.sampling_ratio = int(sampling_ratio) + ctx.gamma = float(gamma) + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + + ext_module.deform_roi_pool_forward( + input, + rois, + offset, + output, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + + ctx.save_for_backward(input, rois, offset) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, offset = ctx.saved_tensors + grad_input = grad_output.new_zeros(input.shape) + grad_offset = grad_output.new_zeros(offset.shape) + + ext_module.deform_roi_pool_backward( + grad_output, + input, + rois, + offset, + grad_input, + grad_offset, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + gamma=ctx.gamma) + if grad_offset.numel() == 0: + grad_offset = None + return grad_input, None, grad_offset, None, None, None, None + + +deform_roi_pool = DeformRoIPoolFunction.apply + + +class DeformRoIPool(nn.Module): + + def __init__(self, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(DeformRoIPool, self).__init__() + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.gamma = float(gamma) + + def forward(self, input, rois, offset=None): + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class DeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size, + output_channels, + deform_fc_channels=1024, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale, + sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + def forward(self, input, rois): + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + return deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + + +class ModulatedDeformRoIPoolPack(DeformRoIPool): + + def __init__(self, + output_size, + output_channels, + deform_fc_channels=1024, + spatial_scale=1.0, + sampling_ratio=0, + gamma=0.1): + super(ModulatedDeformRoIPoolPack, + self).__init__(output_size, spatial_scale, sampling_ratio, gamma) + + self.output_channels = output_channels + self.deform_fc_channels = deform_fc_channels + + self.offset_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 2)) + self.offset_fc[-1].weight.data.zero_() + self.offset_fc[-1].bias.data.zero_() + + self.mask_fc = nn.Sequential( + nn.Linear( + self.output_size[0] * self.output_size[1] * + self.output_channels, self.deform_fc_channels), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_channels, + self.output_size[0] * self.output_size[1] * 1), + nn.Sigmoid()) + self.mask_fc[2].weight.data.zero_() + self.mask_fc[2].bias.data.zero_() + + def forward(self, input, rois): + assert input.size(1) == self.output_channels + x = deform_roi_pool(input, rois, None, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + rois_num = rois.size(0) + offset = self.offset_fc(x.view(rois_num, -1)) + offset = offset.view(rois_num, 2, self.output_size[0], + self.output_size[1]) + mask = self.mask_fc(x.view(rois_num, -1)) + mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1]) + d = deform_roi_pool(input, rois, offset, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.gamma) + return d * mask diff --git a/annotator/uniformer/mmcv/ops/deprecated_wrappers.py b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file is for backward compatibility. +# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks. +import warnings + +from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d + + +class Conv2d_deprecated(Conv2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') + + +class ConvTranspose2d_deprecated(ConvTranspose2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be ' + 'deprecated in the future. Please import them from "mmcv.cnn" ' + 'instead') + + +class MaxPool2d_deprecated(MaxPool2d): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') + + +class Linear_deprecated(Linear): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + 'Importing Linear wrapper from "mmcv.ops" will be deprecated in' + ' the future. Please import them from "mmcv.cnn" instead') diff --git a/annotator/uniformer/mmcv/ops/focal_loss.py b/annotator/uniformer/mmcv/ops/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/focal_loss.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward', + 'softmax_focal_loss_forward', 'softmax_focal_loss_backward' +]) + + +class SigmoidFocalLossFunction(Function): + + @staticmethod + def symbolic(g, input, target, gamma, alpha, weight, reduction): + return g.op( + 'mmcv::MMCVSigmoidFocalLoss', + input, + target, + gamma_f=gamma, + alpha_f=alpha, + weight_f=weight, + reduction_s=reduction) + + @staticmethod + def forward(ctx, + input, + target, + gamma=2.0, + alpha=0.25, + weight=None, + reduction='mean'): + + assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert input.dim() == 2 + assert target.dim() == 1 + assert input.size(0) == target.size(0) + if weight is None: + weight = input.new_empty(0) + else: + assert weight.dim() == 1 + assert input.size(1) == weight.size(0) + ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2} + assert reduction in ctx.reduction_dict.keys() + + ctx.gamma = float(gamma) + ctx.alpha = float(alpha) + ctx.reduction = ctx.reduction_dict[reduction] + + output = input.new_zeros(input.size()) + + ext_module.sigmoid_focal_loss_forward( + input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha) + if ctx.reduction == ctx.reduction_dict['mean']: + output = output.sum() / input.size(0) + elif ctx.reduction == ctx.reduction_dict['sum']: + output = output.sum() + ctx.save_for_backward(input, target, weight) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, target, weight = ctx.saved_tensors + + grad_input = input.new_zeros(input.size()) + + ext_module.sigmoid_focal_loss_backward( + input, + target, + weight, + grad_input, + gamma=ctx.gamma, + alpha=ctx.alpha) + + grad_input *= grad_output + if ctx.reduction == ctx.reduction_dict['mean']: + grad_input /= input.size(0) + return grad_input, None, None, None, None, None + + +sigmoid_focal_loss = SigmoidFocalLossFunction.apply + + +class SigmoidFocalLoss(nn.Module): + + def __init__(self, gamma, alpha, weight=None, reduction='mean'): + super(SigmoidFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.register_buffer('weight', weight) + self.reduction = reduction + + def forward(self, input, target): + return sigmoid_focal_loss(input, target, self.gamma, self.alpha, + self.weight, self.reduction) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(gamma={self.gamma}, ' + s += f'alpha={self.alpha}, ' + s += f'reduction={self.reduction})' + return s + + +class SoftmaxFocalLossFunction(Function): + + @staticmethod + def symbolic(g, input, target, gamma, alpha, weight, reduction): + return g.op( + 'mmcv::MMCVSoftmaxFocalLoss', + input, + target, + gamma_f=gamma, + alpha_f=alpha, + weight_f=weight, + reduction_s=reduction) + + @staticmethod + def forward(ctx, + input, + target, + gamma=2.0, + alpha=0.25, + weight=None, + reduction='mean'): + + assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert input.dim() == 2 + assert target.dim() == 1 + assert input.size(0) == target.size(0) + if weight is None: + weight = input.new_empty(0) + else: + assert weight.dim() == 1 + assert input.size(1) == weight.size(0) + ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2} + assert reduction in ctx.reduction_dict.keys() + + ctx.gamma = float(gamma) + ctx.alpha = float(alpha) + ctx.reduction = ctx.reduction_dict[reduction] + + channel_stats, _ = torch.max(input, dim=1) + input_softmax = input - channel_stats.unsqueeze(1).expand_as(input) + input_softmax.exp_() + + channel_stats = input_softmax.sum(dim=1) + input_softmax /= channel_stats.unsqueeze(1).expand_as(input) + + output = input.new_zeros(input.size(0)) + ext_module.softmax_focal_loss_forward( + input_softmax, + target, + weight, + output, + gamma=ctx.gamma, + alpha=ctx.alpha) + + if ctx.reduction == ctx.reduction_dict['mean']: + output = output.sum() / input.size(0) + elif ctx.reduction == ctx.reduction_dict['sum']: + output = output.sum() + ctx.save_for_backward(input_softmax, target, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input_softmax, target, weight = ctx.saved_tensors + buff = input_softmax.new_zeros(input_softmax.size(0)) + grad_input = input_softmax.new_zeros(input_softmax.size()) + + ext_module.softmax_focal_loss_backward( + input_softmax, + target, + weight, + buff, + grad_input, + gamma=ctx.gamma, + alpha=ctx.alpha) + + grad_input *= grad_output + if ctx.reduction == ctx.reduction_dict['mean']: + grad_input /= input_softmax.size(0) + return grad_input, None, None, None, None, None + + +softmax_focal_loss = SoftmaxFocalLossFunction.apply + + +class SoftmaxFocalLoss(nn.Module): + + def __init__(self, gamma, alpha, weight=None, reduction='mean'): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.register_buffer('weight', weight) + self.reduction = reduction + + def forward(self, input, target): + return softmax_focal_loss(input, target, self.gamma, self.alpha, + self.weight, self.reduction) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(gamma={self.gamma}, ' + s += f'alpha={self.alpha}, ' + s += f'reduction={self.reduction})' + return s diff --git a/annotator/uniformer/mmcv/ops/furthest_point_sample.py b/annotator/uniformer/mmcv/ops/furthest_point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f --- /dev/null +++ b/annotator/uniformer/mmcv/ops/furthest_point_sample.py @@ -0,0 +1,83 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'furthest_point_sampling_forward', + 'furthest_point_sampling_with_dist_forward' +]) + + +class FurthestPointSampling(Function): + """Uses iterative furthest point sampling to select a set of features whose + corresponding points have the furthest distance.""" + + @staticmethod + def forward(ctx, points_xyz: torch.Tensor, + num_points: int) -> torch.Tensor: + """ + Args: + points_xyz (Tensor): (B, N, 3) where N > num_points. + num_points (int): Number of points in the sampled set. + + Returns: + Tensor: (B, num_points) indices of the sampled points. + """ + assert points_xyz.is_contiguous() + + B, N = points_xyz.size()[:2] + output = torch.cuda.IntTensor(B, num_points) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + + ext_module.furthest_point_sampling_forward( + points_xyz, + temp, + output, + b=B, + n=N, + m=num_points, + ) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +class FurthestPointSamplingWithDist(Function): + """Uses iterative furthest point sampling to select a set of features whose + corresponding points have the furthest distance.""" + + @staticmethod + def forward(ctx, points_dist: torch.Tensor, + num_points: int) -> torch.Tensor: + """ + Args: + points_dist (Tensor): (B, N, N) Distance between each point pair. + num_points (int): Number of points in the sampled set. + + Returns: + Tensor: (B, num_points) indices of the sampled points. + """ + assert points_dist.is_contiguous() + + B, N, _ = points_dist.size() + output = points_dist.new_zeros([B, num_points], dtype=torch.int32) + temp = points_dist.new_zeros([B, N]).fill_(1e10) + + ext_module.furthest_point_sampling_with_dist_forward( + points_dist, temp, output, b=B, n=N, m=num_points) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +furthest_point_sample = FurthestPointSampling.apply +furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply diff --git a/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa --- /dev/null +++ b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py @@ -0,0 +1,268 @@ +# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator +# Augmentation (ADA) +# ======================================================================= + +# 1. Definitions + +# "Licensor" means any person or entity that distributes its Work. + +# "Software" means the original work of authorship made available under +# this License. + +# "Work" means the Software and any additions to or derivative works of +# the Software that are made available under this License. + +# The terms "reproduce," "reproduction," "derivative works," and +# "distribution" have the meaning as provided under U.S. copyright law; +# provided, however, that for the purposes of this License, derivative +# works shall not include works that remain separable from, or merely +# link (or bind by name) to the interfaces of, the Work. + +# Works, including the Software, are "made available" under this License +# by including in or with the Work either (a) a copyright notice +# referencing the applicability of this License to the Work, or (b) a +# copy of this License. + +# 2. License Grants + +# 2.1 Copyright Grant. Subject to the terms and conditions of this +# License, each Licensor grants to you a perpetual, worldwide, +# non-exclusive, royalty-free, copyright license to reproduce, +# prepare derivative works of, publicly display, publicly perform, +# sublicense and distribute its Work and any resulting derivative +# works in any form. + +# 3. Limitations + +# 3.1 Redistribution. You may reproduce or distribute the Work only +# if (a) you do so under this License, (b) you include a complete +# copy of this License with your distribution, and (c) you retain +# without modification any copyright, patent, trademark, or +# attribution notices that are present in the Work. + +# 3.2 Derivative Works. You may specify that additional or different +# terms apply to the use, reproduction, and distribution of your +# derivative works of the Work ("Your Terms") only if (a) Your Terms +# provide that the use limitation in Section 3.3 applies to your +# derivative works, and (b) you identify the specific derivative +# works that are subject to Your Terms. Notwithstanding Your Terms, +# this License (including the redistribution requirements in Section +# 3.1) will continue to apply to the Work itself. + +# 3.3 Use Limitation. The Work and any derivative works thereof only +# may be used or intended for use non-commercially. Notwithstanding +# the foregoing, NVIDIA and its affiliates may use the Work and any +# derivative works commercially. As used herein, "non-commercially" +# means for research or evaluation purposes only. + +# 3.4 Patent Claims. If you bring or threaten to bring a patent claim +# against any Licensor (including any claim, cross-claim or +# counterclaim in a lawsuit) to enforce any patents that you allege +# are infringed by any Work, then your rights under this License from +# such Licensor (including the grant in Section 2.1) will terminate +# immediately. + +# 3.5 Trademarks. This License does not grant any rights to use any +# Licensor’s or its affiliates’ names, logos, or trademarks, except +# as necessary to reproduce the notices described in this License. + +# 3.6 Termination. If you violate any term of this License, then your +# rights under this License (including the grant in Section 2.1) will +# terminate immediately. + +# 4. Disclaimer of Warranty. + +# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +# THIS LICENSE. + +# 5. Limitation of Liability. + +# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGES. + +# ======================================================================= + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu']) + + +class FusedBiasLeakyReLUFunctionBackward(Function): + """Calculate second order deviation. + + This function is to compute the second order deviation for the fused leaky + relu operation. + """ + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = ext_module.fused_bias_leakyrelu( + grad_output, + empty, + out, + act=3, + grad=1, + alpha=negative_slope, + scale=scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + + # The second order deviation, in fact, contains two parts, while the + # the first part is zero. Thus, we direct consider the second part + # which is similar with the first order deviation in implementation. + gradgrad_out = ext_module.fused_bias_leakyrelu( + gradgrad_input, + gradgrad_bias.to(out.dtype), + out, + act=3, + grad=1, + alpha=ctx.negative_slope, + scale=ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedBiasLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + + out = ext_module.fused_bias_leakyrelu( + input, + bias, + empty, + act=3, + grad=0, + alpha=negative_slope, + scale=scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedBiasLeakyReLU(nn.Module): + """Fused bias leaky ReLU. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initialization. However, since the + :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the + final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + TODO: Implement the CPU version. + + Args: + channel (int): The channel number of the feature map. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. + scale (float, optional): A scalar to adjust the variance of the feature + map. Defaults to 2**0.5. + """ + + def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5): + super(FusedBiasLeakyReLU, self).__init__() + + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_bias_leakyrelu(input, self.bias, self.negative_slope, + self.scale) + + +def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5): + """Fused bias leaky ReLU function. + + This function is introduced in the StyleGAN2: + http://arxiv.org/abs/1912.04958 + + The bias term comes from the convolution operation. In addition, to keep + the variance of the feature map or gradients unchanged, they also adopt a + scale similarly with Kaiming initialization. However, since the + :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the + final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501 + your own scale. + + Args: + input (torch.Tensor): Input feature map. + bias (nn.Parameter): The bias from convolution operation. + negative_slope (float, optional): Same as nn.LeakyRelu. + Defaults to 0.2. + scale (float, optional): A scalar to adjust the variance of the feature + map. Defaults to 2**0.5. + + Returns: + torch.Tensor: Feature map after non-linear activation. + """ + + if not input.is_cuda: + return bias_leakyrelu_ref(input, bias, negative_slope, scale) + + return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), + negative_slope, scale) + + +def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5): + + if bias is not None: + assert bias.ndim == 1 + assert bias.shape[0] == x.shape[1] + x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) + + x = F.leaky_relu(x, negative_slope) + if scale != 1: + x = x * scale + + return x diff --git a/annotator/uniformer/mmcv/ops/gather_points.py b/annotator/uniformer/mmcv/ops/gather_points.py new file mode 100644 index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/gather_points.py @@ -0,0 +1,57 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['gather_points_forward', 'gather_points_backward']) + + +class GatherPoints(Function): + """Gather points with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) features to gather. + indices (Tensor): (B, M) where M is the number of points. + + Returns: + Tensor: (B, C, M) where M is the number of points. + """ + assert features.is_contiguous() + assert indices.is_contiguous() + + B, npoint = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, npoint) + + ext_module.gather_points_forward( + features, indices, output, b=B, c=C, n=N, npoints=npoint) + + ctx.for_backwards = (indices, C, N) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(indices) + return output + + @staticmethod + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + B, npoint = idx.size() + + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + grad_out_data = grad_out.data.contiguous() + ext_module.gather_points_backward( + grad_out_data, + idx, + grad_features.data, + b=B, + c=C, + n=N, + npoints=npoint) + return grad_features, None + + +gather_points = GatherPoints.apply diff --git a/annotator/uniformer/mmcv/ops/group_points.py b/annotator/uniformer/mmcv/ops/group_points.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/group_points.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader +from .ball_query import ball_query +from .knn import knn + +ext_module = ext_loader.load_ext( + '_ext', ['group_points_forward', 'group_points_backward']) + + +class QueryAndGroup(nn.Module): + """Groups points with a ball query of radius. + + Args: + max_radius (float): The maximum radius of the balls. + If None is given, we will use kNN sampling instead of ball query. + sample_num (int): Maximum number of features to gather in the ball. + min_radius (float, optional): The minimum radius of the balls. + Default: 0. + use_xyz (bool, optional): Whether to use xyz. + Default: True. + return_grouped_xyz (bool, optional): Whether to return grouped xyz. + Default: False. + normalize_xyz (bool, optional): Whether to normalize xyz. + Default: False. + uniform_sample (bool, optional): Whether to sample uniformly. + Default: False + return_unique_cnt (bool, optional): Whether to return the count of + unique samples. Default: False. + return_grouped_idx (bool, optional): Whether to return grouped idx. + Default: False. + """ + + def __init__(self, + max_radius, + sample_num, + min_radius=0, + use_xyz=True, + return_grouped_xyz=False, + normalize_xyz=False, + uniform_sample=False, + return_unique_cnt=False, + return_grouped_idx=False): + super().__init__() + self.max_radius = max_radius + self.min_radius = min_radius + self.sample_num = sample_num + self.use_xyz = use_xyz + self.return_grouped_xyz = return_grouped_xyz + self.normalize_xyz = normalize_xyz + self.uniform_sample = uniform_sample + self.return_unique_cnt = return_unique_cnt + self.return_grouped_idx = return_grouped_idx + if self.return_unique_cnt: + assert self.uniform_sample, \ + 'uniform_sample should be True when ' \ + 'returning the count of unique samples' + if self.max_radius is None: + assert not self.normalize_xyz, \ + 'can not normalize grouped xyz when max_radius is None' + + def forward(self, points_xyz, center_xyz, features=None): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods. + features (Tensor): (B, C, N) Descriptors of the features. + + Returns: + Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. + """ + # if self.max_radius is None, we will perform kNN instead of ball query + # idx is of shape [B, npoint, sample_num] + if self.max_radius is None: + idx = knn(self.sample_num, points_xyz, center_xyz, False) + idx = idx.transpose(1, 2).contiguous() + else: + idx = ball_query(self.min_radius, self.max_radius, self.sample_num, + points_xyz, center_xyz) + + if self.uniform_sample: + unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) + for i_batch in range(idx.shape[0]): + for i_region in range(idx.shape[1]): + unique_ind = torch.unique(idx[i_batch, i_region, :]) + num_unique = unique_ind.shape[0] + unique_cnt[i_batch, i_region] = num_unique + sample_ind = torch.randint( + 0, + num_unique, (self.sample_num - num_unique, ), + dtype=torch.long) + all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) + idx[i_batch, i_region, :] = all_ind + + xyz_trans = points_xyz.transpose(1, 2).contiguous() + # (B, 3, npoint, sample_num) + grouped_xyz = grouping_operation(xyz_trans, idx) + grouped_xyz_diff = grouped_xyz - \ + center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets + if self.normalize_xyz: + grouped_xyz_diff /= self.max_radius + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + # (B, C + 3, npoint, sample_num) + new_features = torch.cat([grouped_xyz_diff, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + assert (self.use_xyz + ), 'Cannot have not features and not use xyz as a feature!' + new_features = grouped_xyz_diff + + ret = [new_features] + if self.return_grouped_xyz: + ret.append(grouped_xyz) + if self.return_unique_cnt: + ret.append(unique_cnt) + if self.return_grouped_idx: + ret.append(idx) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + + +class GroupAll(nn.Module): + """Group xyz with feature. + + Args: + use_xyz (bool): Whether to use xyz. + """ + + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, + xyz: torch.Tensor, + new_xyz: torch.Tensor, + features: torch.Tensor = None): + """ + Args: + xyz (Tensor): (B, N, 3) xyz coordinates of the features. + new_xyz (Tensor): new xyz coordinates of the features. + features (Tensor): (B, C, N) features to group. + + Returns: + Tensor: (B, C + 3, 1, N) Grouped feature. + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + # (B, 3 + C, 1, N) + new_features = torch.cat([grouped_xyz, grouped_features], + dim=1) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features + + +class GroupingOperation(Function): + """Group feature with given index.""" + + @staticmethod + def forward(ctx, features: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, N) tensor of features to group. + indices (Tensor): (B, npoint, nsample) the indices of + features to group with. + + Returns: + Tensor: (B, C, npoint, nsample) Grouped features. + """ + features = features.contiguous() + indices = indices.contiguous() + + B, nfeatures, nsample = indices.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + ext_module.group_points_forward(B, C, N, nfeatures, nsample, features, + indices, output) + + ctx.for_backwards = (indices, N) + return output + + @staticmethod + def backward(ctx, + grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients + of the output from forward. + + Returns: + Tensor: (B, C, N) gradient of the features. + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = torch.cuda.FloatTensor(B, C, N).zero_() + + grad_out_data = grad_out.data.contiguous() + ext_module.group_points_backward(B, C, N, npoint, nsample, + grad_out_data, idx, + grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply diff --git a/annotator/uniformer/mmcv/ops/info.py b/annotator/uniformer/mmcv/ops/info.py new file mode 100644 index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d --- /dev/null +++ b/annotator/uniformer/mmcv/ops/info.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import glob +import os + +import torch + +if torch.__version__ == 'parrots': + import parrots + + def get_compiler_version(): + return 'GCC ' + parrots.version.compiler + + def get_compiling_cuda_version(): + return parrots.version.cuda +else: + from ..utils import ext_loader + ext_module = ext_loader.load_ext( + '_ext', ['get_compiler_version', 'get_compiling_cuda_version']) + + def get_compiler_version(): + return ext_module.get_compiler_version() + + def get_compiling_cuda_version(): + return ext_module.get_compiling_cuda_version() + + +def get_onnxruntime_op_path(): + wildcard = os.path.join( + os.path.abspath(os.path.dirname(os.path.dirname(__file__))), + '_ext_ort.*.so') + + paths = glob.glob(wildcard) + if len(paths) > 0: + return paths[0] + else: + return '' diff --git a/annotator/uniformer/mmcv/ops/iou3d.py b/annotator/uniformer/mmcv/ops/iou3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/iou3d.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward', + 'iou3d_nms_normal_forward' +]) + + +def boxes_iou_bev(boxes_a, boxes_b): + """Calculate boxes IoU in the Bird's Eye View. + + Args: + boxes_a (torch.Tensor): Input boxes a with shape (M, 5). + boxes_b (torch.Tensor): Input boxes b with shape (N, 5). + + Returns: + ans_iou (torch.Tensor): IoU result with shape (M, N). + """ + ans_iou = boxes_a.new_zeros( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))) + + ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(), + boxes_b.contiguous(), ans_iou) + + return ans_iou + + +def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None): + """NMS function GPU implementation (for BEV boxes). The overlap of two + boxes for IoU calculation is defined as the exact overlapping area of the + two boxes. In this function, one can also set ``pre_max_size`` and + ``post_max_size``. + + Args: + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (float): Overlap threshold of NMS. + pre_max_size (int, optional): Max size of boxes before NMS. + Default: None. + post_max_size (int, optional): Max size of boxes after NMS. + Default: None. + + Returns: + torch.Tensor: Indexes after NMS. + """ + assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + if pre_max_size is not None: + order = order[:pre_max_size] + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh) + keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + if post_max_size is not None: + keep = keep[:post_max_size] + return keep + + +def nms_normal_bev(boxes, scores, thresh): + """Normal NMS function GPU implementation (for BEV boxes). The overlap of + two boxes for IoU calculation is defined as the exact overlapping area of + the two boxes WITH their yaw angle set to 0. + + Args: + boxes (torch.Tensor): Input boxes with shape (N, 5). + scores (torch.Tensor): Scores of predicted boxes with shape (N). + thresh (float): Overlap threshold of NMS. + + Returns: + torch.Tensor: Remaining indices with scores in descending order. + """ + assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]' + order = scores.sort(0, descending=True)[1] + + boxes = boxes[order].contiguous() + + keep = torch.zeros(boxes.size(0), dtype=torch.long) + num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh) + return order[keep[:num_out].cuda(boxes.device)].contiguous() diff --git a/annotator/uniformer/mmcv/ops/knn.py b/annotator/uniformer/mmcv/ops/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/knn.py @@ -0,0 +1,77 @@ +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['knn_forward']) + + +class KNN(Function): + r"""KNN (CUDA) based on heap data structure. + Modified from `PAConv `_. + + Find k-nearest points. + """ + + @staticmethod + def forward(ctx, + k: int, + xyz: torch.Tensor, + center_xyz: torch.Tensor = None, + transposed: bool = False) -> torch.Tensor: + """ + Args: + k (int): number of nearest neighbors. + xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N). + xyz coordinates of the features. + center_xyz (Tensor, optional): (B, npoint, 3) if transposed == + False, else (B, 3, npoint). centers of the knn query. + Default: None. + transposed (bool, optional): whether the input tensors are + transposed. Should not explicitly use this keyword when + calling knn (=KNN.apply), just add the fourth param. + Default: False. + + Returns: + Tensor: (B, k, npoint) tensor with the indices of + the features that form k-nearest neighbours. + """ + assert (k > 0) & (k < 100), 'k should be in range(0, 100)' + + if center_xyz is None: + center_xyz = xyz + + if transposed: + xyz = xyz.transpose(2, 1).contiguous() + center_xyz = center_xyz.transpose(2, 1).contiguous() + + assert xyz.is_contiguous() # [B, N, 3] + assert center_xyz.is_contiguous() # [B, npoint, 3] + + center_xyz_device = center_xyz.get_device() + assert center_xyz_device == xyz.get_device(), \ + 'center_xyz and xyz should be put on the same device' + if torch.cuda.current_device() != center_xyz_device: + torch.cuda.set_device(center_xyz_device) + + B, npoint, _ = center_xyz.shape + N = xyz.shape[1] + + idx = center_xyz.new_zeros((B, npoint, k)).int() + dist2 = center_xyz.new_zeros((B, npoint, k)).float() + + ext_module.knn_forward( + xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) + # idx shape to [B, k, npoint] + idx = idx.transpose(2, 1).contiguous() + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None + + +knn = KNN.apply diff --git a/annotator/uniformer/mmcv/ops/masked_conv.py b/annotator/uniformer/mmcv/ops/masked_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b --- /dev/null +++ b/annotator/uniformer/mmcv/ops/masked_conv.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['masked_im2col_forward', 'masked_col2im_forward']) + + +class MaskedConv2dFunction(Function): + + @staticmethod + def symbolic(g, features, mask, weight, bias, padding, stride): + return g.op( + 'mmcv::MMCVMaskedConv2d', + features, + mask, + weight, + bias, + padding_i=padding, + stride_i=stride) + + @staticmethod + def forward(ctx, features, mask, weight, bias, padding=0, stride=1): + assert mask.dim() == 3 and mask.size(0) == 1 + assert features.dim() == 4 and features.size(0) == 1 + assert features.size()[2:] == mask.size()[1:] + pad_h, pad_w = _pair(padding) + stride_h, stride_w = _pair(stride) + if stride_h != 1 or stride_w != 1: + raise ValueError( + 'Stride could not only be 1 in masked_conv2d currently.') + out_channel, in_channel, kernel_h, kernel_w = weight.size() + + batch_size = features.size(0) + out_h = int( + math.floor((features.size(2) + 2 * pad_h - + (kernel_h - 1) - 1) / stride_h + 1)) + out_w = int( + math.floor((features.size(3) + 2 * pad_w - + (kernel_h - 1) - 1) / stride_w + 1)) + mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False) + output = features.new_zeros(batch_size, out_channel, out_h, out_w) + if mask_inds.numel() > 0: + mask_h_idx = mask_inds[:, 0].contiguous() + mask_w_idx = mask_inds[:, 1].contiguous() + data_col = features.new_zeros(in_channel * kernel_h * kernel_w, + mask_inds.size(0)) + ext_module.masked_im2col_forward( + features, + mask_h_idx, + mask_w_idx, + data_col, + kernel_h=kernel_h, + kernel_w=kernel_w, + pad_h=pad_h, + pad_w=pad_w) + + masked_output = torch.addmm(1, bias[:, None], 1, + weight.view(out_channel, -1), data_col) + ext_module.masked_col2im_forward( + masked_output, + mask_h_idx, + mask_w_idx, + output, + height=out_h, + width=out_w, + channels=out_channel) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + return (None, ) * 5 + + +masked_conv2d = MaskedConv2dFunction.apply + + +class MaskedConv2d(nn.Conv2d): + """A MaskedConv2d which inherits the official Conv2d. + + The masked forward doesn't implement the backward function and only + supports the stride parameter to be 1 currently. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super(MaskedConv2d, + self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, input, mask=None): + if mask is None: # fallback to the normal Conv2d + return super(MaskedConv2d, self).forward(input) + else: + return masked_conv2d(input, mask, self.weight, self.bias, + self.padding) diff --git a/annotator/uniformer/mmcv/ops/merge_cells.py b/annotator/uniformer/mmcv/ops/merge_cells.py new file mode 100644 index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf --- /dev/null +++ b/annotator/uniformer/mmcv/ops/merge_cells.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..cnn import ConvModule + + +class BaseMergeCell(nn.Module): + """The basic class for cells used in NAS-FPN and NAS-FCOS. + + BaseMergeCell takes 2 inputs. After applying convolution + on them, they are resized to the target size. Then, + they go through binary_op, which depends on the type of cell. + If with_out_conv is True, the result of output will go through + another convolution layer. + + Args: + in_channels (int): number of input channels in out_conv layer. + out_channels (int): number of output channels in out_conv layer. + with_out_conv (bool): Whether to use out_conv layer + out_conv_cfg (dict): Config dict for convolution layer, which should + contain "groups", "kernel_size", "padding", "bias" to build + out_conv layer. + out_norm_cfg (dict): Config dict for normalization layer in out_conv. + out_conv_order (tuple): The order of conv/norm/activation layers in + out_conv. + with_input1_conv (bool): Whether to use convolution on input1. + with_input2_conv (bool): Whether to use convolution on input2. + input_conv_cfg (dict): Config dict for building input1_conv layer and + input2_conv layer, which is expected to contain the type of + convolution. + Default: None, which means using conv2d. + input_norm_cfg (dict): Config dict for normalization layer in + input1_conv and input2_conv layer. Default: None. + upsample_mode (str): Interpolation method used to resize the output + of input1_conv and input2_conv to target size. Currently, we + support ['nearest', 'bilinear']. Default: 'nearest'. + """ + + def __init__(self, + fused_channels=256, + out_channels=256, + with_out_conv=True, + out_conv_cfg=dict( + groups=1, kernel_size=3, padding=1, bias=True), + out_norm_cfg=None, + out_conv_order=('act', 'conv', 'norm'), + with_input1_conv=False, + with_input2_conv=False, + input_conv_cfg=None, + input_norm_cfg=None, + upsample_mode='nearest'): + super(BaseMergeCell, self).__init__() + assert upsample_mode in ['nearest', 'bilinear'] + self.with_out_conv = with_out_conv + self.with_input1_conv = with_input1_conv + self.with_input2_conv = with_input2_conv + self.upsample_mode = upsample_mode + + if self.with_out_conv: + self.out_conv = ConvModule( + fused_channels, + out_channels, + **out_conv_cfg, + norm_cfg=out_norm_cfg, + order=out_conv_order) + + self.input1_conv = self._build_input_conv( + out_channels, input_conv_cfg, + input_norm_cfg) if with_input1_conv else nn.Sequential() + self.input2_conv = self._build_input_conv( + out_channels, input_conv_cfg, + input_norm_cfg) if with_input2_conv else nn.Sequential() + + def _build_input_conv(self, channel, conv_cfg, norm_cfg): + return ConvModule( + channel, + channel, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True) + + @abstractmethod + def _binary_op(self, x1, x2): + pass + + def _resize(self, x, size): + if x.shape[-2:] == size: + return x + elif x.shape[-2:] < size: + return F.interpolate(x, size=size, mode=self.upsample_mode) + else: + assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0 + kernel_size = x.shape[-1] // size[-1] + x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size) + return x + + def forward(self, x1, x2, out_size=None): + assert x1.shape[:2] == x2.shape[:2] + assert out_size is None or len(out_size) == 2 + if out_size is None: # resize to larger one + out_size = max(x1.size()[2:], x2.size()[2:]) + + x1 = self.input1_conv(x1) + x2 = self.input2_conv(x2) + + x1 = self._resize(x1, out_size) + x2 = self._resize(x2, out_size) + + x = self._binary_op(x1, x2) + if self.with_out_conv: + x = self.out_conv(x) + return x + + +class SumCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super(SumCell, self).__init__(in_channels, out_channels, **kwargs) + + def _binary_op(self, x1, x2): + return x1 + x2 + + +class ConcatCell(BaseMergeCell): + + def __init__(self, in_channels, out_channels, **kwargs): + super(ConcatCell, self).__init__(in_channels * 2, out_channels, + **kwargs) + + def _binary_op(self, x1, x2): + ret = torch.cat([x1, x2], dim=1) + return ret + + +class GlobalPoolingCell(BaseMergeCell): + + def __init__(self, in_channels=None, out_channels=None, **kwargs): + super().__init__(in_channels, out_channels, **kwargs) + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + + def _binary_op(self, x1, x2): + x2_att = self.global_pool(x2).sigmoid() + return x2 + x2_att * x1 diff --git a/annotator/uniformer/mmcv/ops/modulated_deform_conv.py b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..75559579cf053abcc99538606cbb88c723faf783 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py @@ -0,0 +1,282 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair, _single + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..cnn import CONV_LAYERS +from ..utils import ext_loader, print_log + +ext_module = ext_loader.load_ext( + '_ext', + ['modulated_deform_conv_forward', 'modulated_deform_conv_backward']) + + +class ModulatedDeformConv2dFunction(Function): + + @staticmethod + def symbolic(g, input, offset, mask, weight, bias, stride, padding, + dilation, groups, deform_groups): + input_tensors = [input, offset, mask, weight] + if bias is not None: + input_tensors.append(bias) + return g.op( + 'mmcv::MMCVModulatedDeformConv2d', + *input_tensors, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups) + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1): + if input is not None and input.dim() != 4: + raise ValueError( + f'Expected 4D tensor as input, got {input.dim()}D tensor \ + instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deform_groups = deform_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(0) # fake tensor + # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; + # amp won't cast the type of model (float32), but "offset" is cast + # to float16 by nn.Conv2d automatically, leading to the type + # mismatch with input (when it is float32) or weight. + # The flag for whether to use fp16 or amp is the type of "offset", + # we cast weight and input to temporarily support fp16 and amp + # whatever the pytorch version is. + input = input.type_as(offset) + weight = weight.type_as(input) + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty( + ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + ext_module.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + grad_output = grad_output.contiguous() + ext_module.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + kernel_h=weight.size(2), + kernel_w=weight.size(3), + stride_h=ctx.stride[0], + stride_w=ctx.stride[1], + pad_h=ctx.padding[0], + pad_w=ctx.padding[1], + dilation_h=ctx.dilation[0], + dilation_w=ctx.dilation[1], + group=ctx.groups, + deformable_group=ctx.deform_groups, + with_bias=ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None) + + @staticmethod + def _output_size(ctx, input, weight): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = ctx.padding[d] + kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = ctx.stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + 'convolution input is too small (output would be ' + + 'x'.join(map(str, output_size)) + ')') + return output_size + + +modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply + + +class ModulatedDeformConv2d(nn.Module): + + @deprecated_api_warning({'deformable_groups': 'deform_groups'}, + cls_name='ModulatedDeformConv2d') + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=True): + super(ModulatedDeformConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, + *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + +@CONV_LAYERS.register_module('DCNv2') +class ModulatedDeformConv2dPack(ModulatedDeformConv2d): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv + layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConv2dPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + print_log( + f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' + 'version 2.', + logger='root') + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c52dda18b41705705b47dd0e995b124048c16fba --- /dev/null +++ b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py @@ -0,0 +1,358 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.function import Function, once_differentiable + +from annotator.uniformer.mmcv import deprecated_api_warning +from annotator.uniformer.mmcv.cnn import constant_init, xavier_init +from annotator.uniformer.mmcv.cnn.bricks.registry import ATTENTION +from annotator.uniformer.mmcv.runner import BaseModule +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) + + +class MultiScaleDeformableAttnFunction(Function): + + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights, im2col_step): + """GPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + im2col_step (Tensor): The step used in image to column. + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + ctx.im2col_step = im2col_step + output = ext_module.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step=ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, + value_level_start_index, sampling_locations, + attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + """GPU version of backward function. + + Args: + grad_output (Tensor): Gradient + of output tensor of forward. + + Returns: + Tuple[Tensor]: Gradient + of input tensors in forward. + """ + value, value_spatial_shapes, value_level_start_index,\ + sampling_locations, attention_weights = ctx.saved_tensors + grad_value = torch.zeros_like(value) + grad_sampling_loc = torch.zeros_like(sampling_locations) + grad_attn_weight = torch.zeros_like(attention_weights) + + ext_module.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output.contiguous(), + grad_value, + grad_sampling_loc, + grad_attn_weight, + im2col_step=ctx.im2col_step) + + return grad_value, None, None, \ + grad_sampling_loc, grad_attn_weight, None + + +def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + """CPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ =\ + sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape( + bs * num_heads, embed_dims, H_, W_) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * + attention_weights).sum(-1).view(bs, num_heads * embed_dims, + num_queries) + return output.transpose(1, 2).contiguous() + + +@ATTENTION.register_module() +class MultiScaleDeformableAttention(BaseModule): + """An attention module used in Deformable-Detr. + + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + `_. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_identity`. + Default: 0.1. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + im2col_step=64, + dropout=0.1, + batch_first=False, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn( + "You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2) + self.attention_weights = nn.Linear(embed_dims, + num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weights() + + def init_weights(self): + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.) + thetas = torch.arange( + self.num_heads, + dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.num_heads, 1, 1, + 2).repeat(1, self.num_levels, self.num_points, 1) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0., bias=0.) + xavier_init(self.value_proj, distribution='uniform', bias=0.) + xavier_init(self.output_proj, distribution='uniform', bias=0.) + self._is_init = True + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiScaleDeformableAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (Tensor): The key tensor with shape + `(num_key, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_key, bs, embed_dims)`. + identity (Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. Default + None. + reference_points (Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + if not self.batch_first: + # (num_query, bs ,embed_dims) + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity diff --git a/annotator/uniformer/mmcv/ops/nms.py b/annotator/uniformer/mmcv/ops/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9634281f486ab284091786886854c451368052 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/nms.py @@ -0,0 +1,417 @@ +import os + +import numpy as np +import torch + +from annotator.uniformer.mmcv.utils import deprecated_api_warning +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated']) + + +# This function is modified from: https://github.com/pytorch/vision/ +class NMSop(torch.autograd.Function): + + @staticmethod + def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold, + max_num): + is_filtering_by_score = score_threshold > 0 + if is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = torch.nonzero( + valid_mask, as_tuple=False).squeeze(dim=1) + + inds = ext_module.nms( + bboxes, scores, iou_threshold=float(iou_threshold), offset=offset) + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + @staticmethod + def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold, + max_num): + from ..onnx import is_custom_op_loaded + has_custom_op = is_custom_op_loaded() + # TensorRT nms plugin is aligned with original nms in ONNXRuntime + is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT' + if has_custom_op and (not is_trt_backend): + return g.op( + 'mmcv::NonMaxSuppression', + bboxes, + scores, + iou_threshold_f=float(iou_threshold), + offset_i=int(offset)) + else: + from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze + from ..onnx.onnx_utils.symbolic_helper import _size_helper + + boxes = unsqueeze(g, bboxes, 0) + scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) + + if max_num > 0: + max_num = g.op( + 'Constant', + value_t=torch.tensor(max_num, dtype=torch.long)) + else: + dim = g.op('Constant', value_t=torch.tensor(0)) + max_num = _size_helper(g, bboxes, dim) + max_output_per_class = max_num + iou_threshold = g.op( + 'Constant', + value_t=torch.tensor([iou_threshold], dtype=torch.float)) + score_threshold = g.op( + 'Constant', + value_t=torch.tensor([score_threshold], dtype=torch.float)) + nms_out = g.op('NonMaxSuppression', boxes, scores, + max_output_per_class, iou_threshold, + score_threshold) + return squeeze( + g, + select( + g, nms_out, 1, + g.op( + 'Constant', + value_t=torch.tensor([2], dtype=torch.long))), 1) + + +class SoftNMSop(torch.autograd.Function): + + @staticmethod + def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method, + offset): + dets = boxes.new_empty((boxes.size(0), 5), device='cpu') + inds = ext_module.softnms( + boxes.cpu(), + scores.cpu(), + dets.cpu(), + iou_threshold=float(iou_threshold), + sigma=float(sigma), + min_score=float(min_score), + method=int(method), + offset=int(offset)) + return dets, inds + + @staticmethod + def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method, + offset): + from packaging import version + assert version.parse(torch.__version__) >= version.parse('1.7.0') + nms_out = g.op( + 'mmcv::SoftNonMaxSuppression', + boxes, + scores, + iou_threshold_f=float(iou_threshold), + sigma_f=float(sigma), + min_score_f=float(min_score), + method_i=int(method), + offset_i=int(offset), + outputs=2) + return nms_out + + +@deprecated_api_warning({'iou_thr': 'iou_threshold'}) +def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1): + """Dispatch to either CPU or GPU NMS implementations. + + The input can be either torch tensor or numpy array. GPU NMS will be used + if the input is gpu tensor, otherwise CPU NMS + will be used. The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + score_threshold (float): score threshold for NMS. + max_num (int): maximum number of boxes after NMS. + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + + Example: + >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9], + >>> [49.3, 32.9, 51.0, 35.3], + >>> [49.2, 31.8, 51.0, 35.4], + >>> [35.1, 11.5, 39.1, 15.7], + >>> [35.6, 11.8, 39.3, 14.2], + >>> [35.3, 11.5, 39.9, 14.5], + >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\ + dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = nms(boxes, scores, iou_threshold) + >>> assert len(inds) == len(dets) == 3 + """ + assert isinstance(boxes, (torch.Tensor, np.ndarray)) + assert isinstance(scores, (torch.Tensor, np.ndarray)) + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + assert boxes.size(1) == 4 + assert boxes.size(0) == scores.size(0) + assert offset in (0, 1) + + if torch.__version__ == 'parrots': + indata_list = [boxes, scores] + indata_dict = { + 'iou_threshold': float(iou_threshold), + 'offset': int(offset) + } + inds = ext_module.nms(*indata_list, **indata_dict) + else: + inds = NMSop.apply(boxes, scores, iou_threshold, offset, + score_threshold, max_num) + dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + + +@deprecated_api_warning({'iou_thr': 'iou_threshold'}) +def soft_nms(boxes, + scores, + iou_threshold=0.3, + sigma=0.5, + min_score=1e-3, + method='linear', + offset=0): + """Dispatch to only CPU Soft NMS implementations. + + The input can be either a torch tensor or numpy array. + The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + sigma (float): hyperparameter for gaussian method + min_score (float): score filter threshold + method (str): either 'linear' or 'gaussian' + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + + Example: + >>> boxes = np.array([[4., 3., 5., 3.], + >>> [4., 3., 5., 4.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.], + >>> [3., 1., 3., 1.]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5) + >>> assert len(inds) == len(dets) == 5 + """ + + assert isinstance(boxes, (torch.Tensor, np.ndarray)) + assert isinstance(scores, (torch.Tensor, np.ndarray)) + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + assert boxes.size(1) == 4 + assert boxes.size(0) == scores.size(0) + assert offset in (0, 1) + method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2} + assert method in method_dict.keys() + + if torch.__version__ == 'parrots': + dets = boxes.new_empty((boxes.size(0), 5), device='cpu') + indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()] + indata_dict = { + 'iou_threshold': float(iou_threshold), + 'sigma': float(sigma), + 'min_score': min_score, + 'method': method_dict[method], + 'offset': int(offset) + } + inds = ext_module.softnms(*indata_list, **indata_dict) + else: + dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(), + float(iou_threshold), float(sigma), + float(min_score), method_dict[method], + int(offset)) + + dets = dets[:inds.size(0)] + + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + else: + return dets.to(device=boxes.device), inds.to(device=boxes.device) + + +def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): + """Performs non-maximum suppression in a batched fashion. + + Modified from https://github.com/pytorch/vision/blob + /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + + Arguments: + boxes (torch.Tensor): boxes in shape (N, 4). + scores (torch.Tensor): scores in shape (N, ). + idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). + nms_cfg (dict): specify nms type and other parameters like iou_thr. + Possible keys includes the following. + + - iou_thr (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. + + Returns: + tuple: kept dets and indice. + """ + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes + offsets[:, None] + + nms_type = nms_cfg_.pop('type', 'nms') + nms_op = eval(nms_type) + + split_thr = nms_cfg_.pop('split_thr', 10000) + # Won't split to multiple nms nodes when exporting to onnx + if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export(): + dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + # -1 indexing works abnormal in TensorRT + # This assumes `dets` has 5 dimensions where + # the last dimension is score. + # TODO: more elegant way to handle the dimension issue. + # Some type of nms would reweight the score, such as SoftNMS + scores = dets[:, 4] + else: + max_num = nms_cfg_.pop('max_num', -1) + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + # Some type of nms would reweight the score, such as SoftNMS + scores_after_nms = scores.new_zeros(scores.size()) + for id in torch.unique(idxs): + mask = (idxs == id).nonzero(as_tuple=False).view(-1) + dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + scores_after_nms[mask[keep]] = dets[:, -1] + keep = total_mask.nonzero(as_tuple=False).view(-1) + + scores, inds = scores_after_nms[keep].sort(descending=True) + keep = keep[inds] + boxes = boxes[keep] + + if max_num > 0: + keep = keep[:max_num] + boxes = boxes[:max_num] + scores = scores[:max_num] + + return torch.cat([boxes, scores[:, None]], -1), keep + + +def nms_match(dets, iou_threshold): + """Matched dets into different groups by NMS. + + NMS match is Similar to NMS but when a bbox is suppressed, nms match will + record the indice of suppressed bbox and form a group with the indice of + kept bbox. In each group, indice is sorted as score order. + + Arguments: + dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5). + iou_thr (float): IoU thresh for NMS. + + Returns: + List[torch.Tensor | np.ndarray]: The outer list corresponds different + matched group, the inner Tensor corresponds the indices for a group + in score order. + """ + if dets.shape[0] == 0: + matched = [] + else: + assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \ + f'but get {dets.shape}' + if isinstance(dets, torch.Tensor): + dets_t = dets.detach().cpu() + else: + dets_t = torch.from_numpy(dets) + indata_list = [dets_t] + indata_dict = {'iou_threshold': float(iou_threshold)} + matched = ext_module.nms_match(*indata_list, **indata_dict) + if torch.__version__ == 'parrots': + matched = matched.tolist() + + if isinstance(dets, torch.Tensor): + return [dets.new_tensor(m, dtype=torch.long) for m in matched] + else: + return [np.array(m, dtype=np.int) for m in matched] + + +def nms_rotated(dets, scores, iou_threshold, labels=None): + """Performs non-maximum suppression (NMS) on the rotated boxes according to + their intersection-over-union (IoU). + + Rotated NMS iteratively removes lower scoring rotated boxes which have an + IoU greater than iou_threshold with another (higher scoring) rotated box. + + Args: + boxes (Tensor): Rotated boxes in shape (N, 5). They are expected to \ + be in (x_ctr, y_ctr, width, height, angle_radian) format. + scores (Tensor): scores in shape (N, ). + iou_threshold (float): IoU thresh for NMS. + labels (Tensor): boxes' label in shape (N,). + + Returns: + tuple: kept dets(boxes and scores) and indice, which is always the \ + same data type as the input. + """ + if dets.shape[0] == 0: + return dets, None + multi_label = labels is not None + if multi_label: + dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1) + else: + dets_wl = dets + _, order = scores.sort(0, descending=True) + dets_sorted = dets_wl.index_select(0, order) + + if torch.__version__ == 'parrots': + keep_inds = ext_module.nms_rotated( + dets_wl, + scores, + order, + dets_sorted, + iou_threshold=iou_threshold, + multi_label=multi_label) + else: + keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted, + iou_threshold, multi_label) + dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), + dim=1) + return dets, keep_inds diff --git a/annotator/uniformer/mmcv/ops/pixel_group.py b/annotator/uniformer/mmcv/ops/pixel_group.py new file mode 100644 index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/pixel_group.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['pixel_group']) + + +def pixel_group(score, mask, embedding, kernel_label, kernel_contour, + kernel_region_num, distance_threshold): + """Group pixels into text instances, which is widely used text detection + methods. + + Arguments: + score (np.array or Tensor): The foreground score with size hxw. + mask (np.array or Tensor): The foreground mask with size hxw. + embedding (np.array or Tensor): The embedding with size hxwxc to + distinguish instances. + kernel_label (np.array or Tensor): The instance kernel index with + size hxw. + kernel_contour (np.array or Tensor): The kernel contour with size hxw. + kernel_region_num (int): The instance kernel region number. + distance_threshold (float): The embedding distance threshold between + kernel and pixel in one instance. + + Returns: + pixel_assignment (List[List[float]]): The instance coordinate list. + Each element consists of averaged confidence, pixel number, and + coordinates (x_i, y_i for all pixels) in order. + """ + assert isinstance(score, (torch.Tensor, np.ndarray)) + assert isinstance(mask, (torch.Tensor, np.ndarray)) + assert isinstance(embedding, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_label, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_contour, (torch.Tensor, np.ndarray)) + assert isinstance(kernel_region_num, int) + assert isinstance(distance_threshold, float) + + if isinstance(score, np.ndarray): + score = torch.from_numpy(score) + if isinstance(mask, np.ndarray): + mask = torch.from_numpy(mask) + if isinstance(embedding, np.ndarray): + embedding = torch.from_numpy(embedding) + if isinstance(kernel_label, np.ndarray): + kernel_label = torch.from_numpy(kernel_label) + if isinstance(kernel_contour, np.ndarray): + kernel_contour = torch.from_numpy(kernel_contour) + + if torch.__version__ == 'parrots': + label = ext_module.pixel_group( + score, + mask, + embedding, + kernel_label, + kernel_contour, + kernel_region_num=kernel_region_num, + distance_threshold=distance_threshold) + label = label.tolist() + label = label[0] + list_index = kernel_region_num + pixel_assignment = [] + for x in range(kernel_region_num): + pixel_assignment.append( + np.array( + label[list_index:list_index + int(label[x])], + dtype=np.float)) + list_index = list_index + int(label[x]) + else: + pixel_assignment = ext_module.pixel_group(score, mask, embedding, + kernel_label, kernel_contour, + kernel_region_num, + distance_threshold) + return pixel_assignment diff --git a/annotator/uniformer/mmcv/ops/point_sample.py b/annotator/uniformer/mmcv/ops/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..267f4b3c56630acd85f9bdc630b7be09abab0aba --- /dev/null +++ b/annotator/uniformer/mmcv/ops/point_sample.py @@ -0,0 +1,336 @@ +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa + +from os import path as osp + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair +from torch.onnx.operators import shape_as_tensor + + +def bilinear_grid_sample(im, grid, align_corners=False): + """Given an input and a flow-field grid, computes the output using input + values and pixel locations from grid. Supported only bilinear interpolation + method to sample the input pixels. + + Args: + im (torch.Tensor): Input feature map, shape (N, C, H, W) + grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) + align_corners {bool}: If set to True, the extrema (-1 and 1) are + considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. + Returns: + torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg) + """ + n, c, h, w = im.shape + gn, gh, gw, _ = grid.shape + assert n == gn + + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + + if align_corners: + x = ((x + 1) / 2) * (w - 1) + y = ((y + 1) / 2) * (h - 1) + else: + x = ((x + 1) * w - 1) / 2 + y = ((y + 1) * h - 1) / 2 + + x = x.view(n, -1) + y = y.view(n, -1) + + x0 = torch.floor(x).long() + y0 = torch.floor(y).long() + x1 = x0 + 1 + y1 = y0 + 1 + + wa = ((x1 - x) * (y1 - y)).unsqueeze(1) + wb = ((x1 - x) * (y - y0)).unsqueeze(1) + wc = ((x - x0) * (y1 - y)).unsqueeze(1) + wd = ((x - x0) * (y - y0)).unsqueeze(1) + + # Apply default for grid_sample function zero padding + im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0) + padded_h = h + 2 + padded_w = w + 2 + # save points positions after padding + x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1 + + # Clip coordinates to padded image size + x0 = torch.where(x0 < 0, torch.tensor(0), x0) + x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0) + x1 = torch.where(x1 < 0, torch.tensor(0), x1) + x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1) + y0 = torch.where(y0 < 0, torch.tensor(0), y0) + y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0) + y1 = torch.where(y1 < 0, torch.tensor(0), y1) + y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1) + + im_padded = im_padded.view(n, c, -1) + + x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) + x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) + x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) + x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) + + Ia = torch.gather(im_padded, 2, x0_y0) + Ib = torch.gather(im_padded, 2, x0_y1) + Ic = torch.gather(im_padded, 2, x1_y0) + Id = torch.gather(im_padded, 2, x1_y1) + + return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) + + +def is_in_onnx_export_without_custom_ops(): + from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + return torch.onnx.is_in_onnx_export( + ) and not osp.exists(ort_custom_op_path) + + +def normalize(grid): + """Normalize input grid from [-1, 1] to [0, 1] + Args: + grid (Tensor): The grid to be normalize, range [-1, 1]. + Returns: + Tensor: Normalized grid, range [0, 1]. + """ + + return (grid + 1.0) / 2.0 + + +def denormalize(grid): + """Denormalize input grid from range [0, 1] to [-1, 1] + Args: + grid (Tensor): The grid to be denormalize, range [0, 1]. + Returns: + Tensor: Denormalized grid, range [-1, 1]. + """ + + return grid * 2.0 - 1.0 + + +def generate_grid(num_grid, size, device): + """Generate regular square grid of points in [0, 1] x [0, 1] coordinate + space. + + Args: + num_grid (int): The number of grids to sample, one for each region. + size (tuple(int, int)): The side size of the regular grid. + device (torch.device): Desired device of returned tensor. + + Returns: + (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that + contains coordinates for the regular grids. + """ + + affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device) + grid = F.affine_grid( + affine_trans, torch.Size((1, 1, *size)), align_corners=False) + grid = normalize(grid) + return grid.view(1, -1, 2).expand(num_grid, -1, -1) + + +def rel_roi_point_to_abs_img_point(rois, rel_roi_points): + """Convert roi based relative point coordinates to image based absolute + point coordinates. + + Args: + rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5) + rel_roi_points (Tensor): Point coordinates inside RoI, relative to + RoI, location, range (0, 1), shape (N, P, 2) + Returns: + Tensor: Image based absolute point coordinates, shape (N, P, 2) + """ + + with torch.no_grad(): + assert rel_roi_points.size(0) == rois.size(0) + assert rois.dim() == 2 + assert rel_roi_points.dim() == 3 + assert rel_roi_points.size(2) == 2 + # remove batch idx + if rois.size(1) == 5: + rois = rois[:, 1:] + abs_img_points = rel_roi_points.clone() + # To avoid an error during exporting to onnx use independent + # variables instead inplace computation + xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0]) + ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1]) + xs += rois[:, None, 0] + ys += rois[:, None, 1] + abs_img_points = torch.stack([xs, ys], dim=2) + return abs_img_points + + +def get_shape_from_feature_map(x): + """Get spatial resolution of input feature map considering exporting to + onnx mode. + + Args: + x (torch.Tensor): Input tensor, shape (N, C, H, W) + Returns: + torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2) + """ + if torch.onnx.is_in_onnx_export(): + img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to( + x.device).float() + else: + img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to( + x.device).float() + return img_shape + + +def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.): + """Convert image based absolute point coordinates to image based relative + coordinates for sampling. + + Args: + abs_img_points (Tensor): Image based absolute point coordinates, + shape (N, P, 2) + img (tuple/Tensor): (height, width) of image or feature map. + spatial_scale (float): Scale points by this factor. Default: 1. + + Returns: + Tensor: Image based relative point coordinates for sampling, + shape (N, P, 2) + """ + + assert (isinstance(img, tuple) and len(img) == 2) or \ + (isinstance(img, torch.Tensor) and len(img.shape) == 4) + + if isinstance(img, tuple): + h, w = img + scale = torch.tensor([w, h], + dtype=torch.float, + device=abs_img_points.device) + scale = scale.view(1, 1, 2) + else: + scale = get_shape_from_feature_map(img) + + return abs_img_points / scale * spatial_scale + + +def rel_roi_point_to_rel_img_point(rois, + rel_roi_points, + img, + spatial_scale=1.): + """Convert roi based relative point coordinates to image based absolute + point coordinates. + + Args: + rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5) + rel_roi_points (Tensor): Point coordinates inside RoI, relative to + RoI, location, range (0, 1), shape (N, P, 2) + img (tuple/Tensor): (height, width) of image or feature map. + spatial_scale (float): Scale points by this factor. Default: 1. + + Returns: + Tensor: Image based relative point coordinates for sampling, + shape (N, P, 2) + """ + + abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points) + rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img, + spatial_scale) + + return rel_img_point + + +def point_sample(input, points, align_corners=False, **kwargs): + """A wrapper around :func:`grid_sample` to support 3D point_coords tensors + Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to + lie inside ``[0, 1] x [0, 1]`` square. + + Args: + input (Tensor): Feature map, shape (N, C, H, W). + points (Tensor): Image based absolute point coordinates (normalized), + range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2). + align_corners (bool): Whether align_corners. Default: False + + Returns: + Tensor: Features of `point` on `input`, shape (N, C, P) or + (N, C, Hgrid, Wgrid). + """ + + add_dim = False + if points.dim() == 3: + add_dim = True + points = points.unsqueeze(2) + if is_in_onnx_export_without_custom_ops(): + # If custom ops for onnx runtime not compiled use python + # implementation of grid_sample function to make onnx graph + # with supported nodes + output = bilinear_grid_sample( + input, denormalize(points), align_corners=align_corners) + else: + output = F.grid_sample( + input, denormalize(points), align_corners=align_corners, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +class SimpleRoIAlign(nn.Module): + + def __init__(self, output_size, spatial_scale, aligned=True): + """Simple RoI align in PointRend, faster than standard RoIAlign. + + Args: + output_size (tuple[int]): h, w + spatial_scale (float): scale the input boxes by this number + aligned (bool): if False, use the legacy implementation in + MMDetection, align_corners=True will be used in F.grid_sample. + If True, align the results more perfectly. + """ + + super(SimpleRoIAlign, self).__init__() + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + # to be consistent with other RoI ops + self.use_torchvision = False + self.aligned = aligned + + def forward(self, features, rois): + num_imgs = features.size(0) + num_rois = rois.size(0) + rel_roi_points = generate_grid( + num_rois, self.output_size, device=rois.device) + + if torch.onnx.is_in_onnx_export(): + rel_img_points = rel_roi_point_to_rel_img_point( + rois, rel_roi_points, features, self.spatial_scale) + rel_img_points = rel_img_points.reshape(num_imgs, -1, + *rel_img_points.shape[1:]) + point_feats = point_sample( + features, rel_img_points, align_corners=not self.aligned) + point_feats = point_feats.transpose(1, 2) + else: + point_feats = [] + for batch_ind in range(num_imgs): + # unravel batch dim + feat = features[batch_ind].unsqueeze(0) + inds = (rois[:, 0].long() == batch_ind) + if inds.any(): + rel_img_points = rel_roi_point_to_rel_img_point( + rois[inds], rel_roi_points[inds], feat, + self.spatial_scale).unsqueeze(0) + point_feat = point_sample( + feat, rel_img_points, align_corners=not self.aligned) + point_feat = point_feat.squeeze(0).transpose(0, 1) + point_feats.append(point_feat) + + point_feats = torch.cat(point_feats, dim=0) + + channels = features.size(1) + roi_feats = point_feats.reshape(num_rois, channels, *self.output_size) + + return roi_feats + + def __repr__(self): + format_str = self.__class__.__name__ + format_str += '(output_size={}, spatial_scale={}'.format( + self.output_size, self.spatial_scale) + return format_str diff --git a/annotator/uniformer/mmcv/ops/points_in_boxes.py b/annotator/uniformer/mmcv/ops/points_in_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/points_in_boxes.py @@ -0,0 +1,133 @@ +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward', + 'points_in_boxes_all_forward' +]) + + +def points_in_boxes_part(points, boxes): + """Find the box in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in + LiDAR/DEPTH coordinate, (x, y, z) is the bottom center + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M), default background = -1 + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + + box_idxs_of_pts = points.new_zeros((batch_size, num_points), + dtype=torch.int).fill_(-1) + + # If manually put the tensor 'points' or 'boxes' on a device + # which is not the current device, some temporary variables + # will be created on the current device in the cuda op, + # and the output will be incorrect. + # Therefore, we force the current device to be the same + # as the device of the tensors if it was not. + # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305 + # for the incorrect output before the fix. + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_part_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts + + +def points_in_boxes_cpu(points, boxes): + """Find all boxes in which each point is (CPU). The CPU version of + :meth:`points_in_boxes_all`. + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in + LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert points.shape[0] == boxes.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + point_indices = points.new_zeros((batch_size, num_boxes, num_points), + dtype=torch.int) + for b in range(batch_size): + ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(), + points[b].float().contiguous(), + point_indices[b]) + point_indices = point_indices.transpose(1, 2) + + return point_indices + + +def points_in_boxes_all(points, boxes): + """Find all boxes in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], + (x, y, z) is the bottom center. + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0. + """ + assert boxes.shape[0] == points.shape[0], \ + 'Points and boxes should have the same batch size, ' \ + f'but got {boxes.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + 'boxes dimension should be 7, ' \ + f'but got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + 'points dimension should be 3, ' \ + f'but got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + num_boxes = boxes.shape[1] + + box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes), + dtype=torch.int).fill_(0) + + # Same reason as line 25-32 + points_device = points.get_device() + assert points_device == boxes.get_device(), \ + 'Points and boxes should be put on the same device' + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + + ext_module.points_in_boxes_all_forward(boxes.contiguous(), + points.contiguous(), + box_idxs_of_pts) + + return box_idxs_of_pts diff --git a/annotator/uniformer/mmcv/ops/points_sampler.py b/annotator/uniformer/mmcv/ops/points_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a802a74fd6c3610d9ae178e6201f47423eca7ad1 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/points_sampler.py @@ -0,0 +1,177 @@ +from typing import List + +import torch +from torch import nn as nn + +from annotator.uniformer.mmcv.runner import force_fp32 +from .furthest_point_sample import (furthest_point_sample, + furthest_point_sample_with_dist) + + +def calc_square_dist(point_feat_a, point_feat_b, norm=True): + """Calculating square distance between a and b. + + Args: + point_feat_a (Tensor): (B, N, C) Feature vector of each point. + point_feat_b (Tensor): (B, M, C) Feature vector of each point. + norm (Bool, optional): Whether to normalize the distance. + Default: True. + + Returns: + Tensor: (B, N, M) Distance between each pair points. + """ + num_channel = point_feat_a.shape[-1] + # [bs, n, 1] + a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1) + # [bs, 1, m] + b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1) + + corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2)) + + dist = a_square + b_square - 2 * corr_matrix + if norm: + dist = torch.sqrt(dist) / num_channel + return dist + + +def get_sampler_cls(sampler_type): + """Get the type and mode of points sampler. + + Args: + sampler_type (str): The type of points sampler. + The valid value are "D-FPS", "F-FPS", or "FS". + + Returns: + class: Points sampler type. + """ + sampler_mappings = { + 'D-FPS': DFPSSampler, + 'F-FPS': FFPSSampler, + 'FS': FSSampler, + } + try: + return sampler_mappings[sampler_type] + except KeyError: + raise KeyError( + f'Supported `sampler_type` are {sampler_mappings.keys()}, but got \ + {sampler_type}') + + +class PointsSampler(nn.Module): + """Points sampling. + + Args: + num_point (list[int]): Number of sample points. + fps_mod_list (list[str], optional): Type of FPS method, valid mod + ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS']. + F-FPS: using feature distances for FPS. + D-FPS: using Euclidean distances of points for FPS. + FS: using F-FPS and D-FPS simultaneously. + fps_sample_range_list (list[int], optional): + Range of points to apply FPS. Default: [-1]. + """ + + def __init__(self, + num_point: List[int], + fps_mod_list: List[str] = ['D-FPS'], + fps_sample_range_list: List[int] = [-1]): + super().__init__() + # FPS would be applied to different fps_mod in the list, + # so the length of the num_point should be equal to + # fps_mod_list and fps_sample_range_list. + assert len(num_point) == len(fps_mod_list) == len( + fps_sample_range_list) + self.num_point = num_point + self.fps_sample_range_list = fps_sample_range_list + self.samplers = nn.ModuleList() + for fps_mod in fps_mod_list: + self.samplers.append(get_sampler_cls(fps_mod)()) + self.fp16_enabled = False + + @force_fp32() + def forward(self, points_xyz, features): + """ + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + features (Tensor): (B, C, N) Descriptors of the features. + + Returns: + Tensor: (B, npoint, sample_num) Indices of sampled points. + """ + indices = [] + last_fps_end_index = 0 + + for fps_sample_range, sampler, npoint in zip( + self.fps_sample_range_list, self.samplers, self.num_point): + assert fps_sample_range < points_xyz.shape[1] + + if fps_sample_range == -1: + sample_points_xyz = points_xyz[:, last_fps_end_index:] + if features is not None: + sample_features = features[:, :, last_fps_end_index:] + else: + sample_features = None + else: + sample_points_xyz = \ + points_xyz[:, last_fps_end_index:fps_sample_range] + if features is not None: + sample_features = features[:, :, last_fps_end_index: + fps_sample_range] + else: + sample_features = None + + fps_idx = sampler(sample_points_xyz.contiguous(), sample_features, + npoint) + + indices.append(fps_idx + last_fps_end_index) + last_fps_end_index += fps_sample_range + indices = torch.cat(indices, dim=1) + + return indices + + +class DFPSSampler(nn.Module): + """Using Euclidean distances of points for FPS.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with D-FPS.""" + fps_idx = furthest_point_sample(points.contiguous(), npoint) + return fps_idx + + +class FFPSSampler(nn.Module): + """Using feature distances for FPS.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with F-FPS.""" + assert features is not None, \ + 'feature input to FFPS_Sampler should not be None' + features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2) + features_dist = calc_square_dist( + features_for_fps, features_for_fps, norm=False) + fps_idx = furthest_point_sample_with_dist(features_dist, npoint) + return fps_idx + + +class FSSampler(nn.Module): + """Using F-FPS and D-FPS simultaneously.""" + + def __init__(self): + super().__init__() + + def forward(self, points, features, npoint): + """Sampling points with FS_Sampling.""" + assert features is not None, \ + 'feature input to FS_Sampler should not be None' + ffps_sampler = FFPSSampler() + dfps_sampler = DFPSSampler() + fps_idx_ffps = ffps_sampler(points, features, npoint) + fps_idx_dfps = dfps_sampler(points, features, npoint) + fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1) + return fps_idx diff --git a/annotator/uniformer/mmcv/ops/psa_mask.py b/annotator/uniformer/mmcv/ops/psa_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/psa_mask.py @@ -0,0 +1,92 @@ +# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa +from torch import nn +from torch.autograd import Function +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['psamask_forward', 'psamask_backward']) + + +class PSAMaskFunction(Function): + + @staticmethod + def symbolic(g, input, psa_type, mask_size): + return g.op( + 'mmcv::MMCVPSAMask', + input, + psa_type_i=psa_type, + mask_size_i=mask_size) + + @staticmethod + def forward(ctx, input, psa_type, mask_size): + ctx.psa_type = psa_type + ctx.mask_size = _pair(mask_size) + ctx.save_for_backward(input) + + h_mask, w_mask = ctx.mask_size + batch_size, channels, h_feature, w_feature = input.size() + assert channels == h_mask * w_mask + output = input.new_zeros( + (batch_size, h_feature * w_feature, h_feature, w_feature)) + + ext_module.psamask_forward( + input, + output, + psa_type=psa_type, + num_=batch_size, + h_feature=h_feature, + w_feature=w_feature, + h_mask=h_mask, + w_mask=w_mask, + half_h_mask=(h_mask - 1) // 2, + half_w_mask=(w_mask - 1) // 2) + return output + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + psa_type = ctx.psa_type + h_mask, w_mask = ctx.mask_size + batch_size, channels, h_feature, w_feature = input.size() + grad_input = grad_output.new_zeros( + (batch_size, channels, h_feature, w_feature)) + ext_module.psamask_backward( + grad_output, + grad_input, + psa_type=psa_type, + num_=batch_size, + h_feature=h_feature, + w_feature=w_feature, + h_mask=h_mask, + w_mask=w_mask, + half_h_mask=(h_mask - 1) // 2, + half_w_mask=(w_mask - 1) // 2) + return grad_input, None, None, None + + +psa_mask = PSAMaskFunction.apply + + +class PSAMask(nn.Module): + + def __init__(self, psa_type, mask_size=None): + super(PSAMask, self).__init__() + assert psa_type in ['collect', 'distribute'] + if psa_type == 'collect': + psa_type_enum = 0 + else: + psa_type_enum = 1 + self.psa_type_enum = psa_type_enum + self.mask_size = mask_size + self.psa_type = psa_type + + def forward(self, input): + return psa_mask(input, self.psa_type_enum, self.mask_size) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(psa_type={self.psa_type}, ' + s += f'mask_size={self.mask_size})' + return s diff --git a/annotator/uniformer/mmcv/ops/roi_align.py b/annotator/uniformer/mmcv/ops/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_align.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import deprecated_api_warning, ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_align_forward', 'roi_align_backward']) + + +class RoIAlignFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, + pool_mode, aligned): + from ..onnx import is_custom_op_loaded + has_custom_op = is_custom_op_loaded() + if has_custom_op: + return g.op( + 'mmcv::MMCVRoiAlign', + input, + rois, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=sampling_ratio, + mode_s=pool_mode, + aligned_i=aligned) + else: + from torch.onnx.symbolic_opset9 import sub, squeeze + from torch.onnx.symbolic_helper import _slice_helper + from torch.onnx import TensorProtoDataType + # batch_indices = rois[:, 0].long() + batch_indices = _slice_helper( + g, rois, axes=[1], starts=[0], ends=[1]) + batch_indices = squeeze(g, batch_indices, 1) + batch_indices = g.op( + 'Cast', batch_indices, to_i=TensorProtoDataType.INT64) + # rois = rois[:, 1:] + rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) + if aligned: + # rois -= 0.5/spatial_scale + aligned_offset = g.op( + 'Constant', + value_t=torch.tensor([0.5 / spatial_scale], + dtype=torch.float32)) + rois = sub(g, rois, aligned_offset) + # roi align + return g.op( + 'RoiAlign', + input, + rois, + batch_indices, + output_height_i=output_size[0], + output_width_i=output_size[1], + spatial_scale_f=spatial_scale, + sampling_ratio_i=max(0, sampling_ratio), + mode_s=pool_mode) + + @staticmethod + def forward(ctx, + input, + rois, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + assert pool_mode in ('max', 'avg') + ctx.pool_mode = 0 if pool_mode == 'max' else 1 + ctx.aligned = aligned + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + if ctx.pool_mode == 0: + argmax_y = input.new_zeros(output_shape) + argmax_x = input.new_zeros(output_shape) + else: + argmax_y = input.new_zeros(0) + argmax_x = input.new_zeros(0) + + ext_module.roi_align_forward( + input, + rois, + output, + argmax_y, + argmax_x, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + + ctx.save_for_backward(rois, argmax_y, argmax_x) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax_y, argmax_x = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + # complex head architecture may cause grad_output uncontiguous. + grad_output = grad_output.contiguous() + ext_module.roi_align_backward( + grad_output, + rois, + argmax_y, + argmax_x, + grad_input, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + pool_mode=ctx.pool_mode, + aligned=ctx.aligned) + return grad_input, None, None, None, None, None, None + + +roi_align = RoIAlignFunction.apply + + +class RoIAlign(nn.Module): + """RoI align pooling layer. + + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each + output sample. 0 to take samples densely for current models. + pool_mode (str, 'avg' or 'max'): pooling mode in each bin. + aligned (bool): if False, use the legacy implementation in + MMDetection. If True, align the results more perfectly. + use_torchvision (bool): whether to use roi_align from torchvision. + + Note: + The implementation of RoIAlign when aligned=True is modified from + https://github.com/facebookresearch/detectron2/ + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; + + The difference does not make a difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + @deprecated_api_warning( + { + 'out_size': 'output_size', + 'sample_num': 'sampling_ratio' + }, + cls_name='RoIAlign') + def __init__(self, + output_size, + spatial_scale=1.0, + sampling_ratio=0, + pool_mode='avg', + aligned=True, + use_torchvision=False): + super(RoIAlign, self).__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.pool_mode = pool_mode + self.aligned = aligned + self.use_torchvision = use_torchvision + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx5 boxes. First column is the index into N.\ + The other 4 columns are xyxy. + """ + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + if 'aligned' in tv_roi_align.__code__.co_varnames: + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.aligned) + else: + if self.aligned: + rois -= rois.new_tensor([0.] + + [0.5 / self.spatial_scale] * 4) + return tv_roi_align(input, rois, self.output_size, + self.spatial_scale, self.sampling_ratio) + else: + return roi_align(input, rois, self.output_size, self.spatial_scale, + self.sampling_ratio, self.pool_mode, self.aligned) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale}, ' + s += f'sampling_ratio={self.sampling_ratio}, ' + s += f'pool_mode={self.pool_mode}, ' + s += f'aligned={self.aligned}, ' + s += f'use_torchvision={self.use_torchvision})' + return s diff --git a/annotator/uniformer/mmcv/ops/roi_align_rotated.py b/annotator/uniformer/mmcv/ops/roi_align_rotated.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_align_rotated.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward']) + + +class RoIAlignRotatedFunction(Function): + + @staticmethod + def symbolic(g, features, rois, out_size, spatial_scale, sample_num, + aligned, clockwise): + if isinstance(out_size, int): + out_h = out_size + out_w = out_size + elif isinstance(out_size, tuple): + assert len(out_size) == 2 + assert isinstance(out_size[0], int) + assert isinstance(out_size[1], int) + out_h, out_w = out_size + else: + raise TypeError( + '"out_size" must be an integer or tuple of integers') + return g.op( + 'mmcv::MMCVRoIAlignRotated', + features, + rois, + output_height_i=out_h, + output_width_i=out_h, + spatial_scale_f=spatial_scale, + sampling_ratio_i=sample_num, + aligned_i=aligned, + clockwise_i=clockwise) + + @staticmethod + def forward(ctx, + features, + rois, + out_size, + spatial_scale, + sample_num=0, + aligned=True, + clockwise=False): + if isinstance(out_size, int): + out_h = out_size + out_w = out_size + elif isinstance(out_size, tuple): + assert len(out_size) == 2 + assert isinstance(out_size[0], int) + assert isinstance(out_size[1], int) + out_h, out_w = out_size + else: + raise TypeError( + '"out_size" must be an integer or tuple of integers') + ctx.spatial_scale = spatial_scale + ctx.sample_num = sample_num + ctx.aligned = aligned + ctx.clockwise = clockwise + ctx.save_for_backward(rois) + ctx.feature_size = features.size() + + batch_size, num_channels, data_height, data_width = features.size() + num_rois = rois.size(0) + + output = features.new_zeros(num_rois, num_channels, out_h, out_w) + ext_module.roi_align_rotated_forward( + features, + rois, + output, + pooled_height=out_h, + pooled_width=out_w, + spatial_scale=spatial_scale, + sample_num=sample_num, + aligned=aligned, + clockwise=clockwise) + return output + + @staticmethod + def backward(ctx, grad_output): + feature_size = ctx.feature_size + spatial_scale = ctx.spatial_scale + aligned = ctx.aligned + clockwise = ctx.clockwise + sample_num = ctx.sample_num + rois = ctx.saved_tensors[0] + assert feature_size is not None + batch_size, num_channels, data_height, data_width = feature_size + + out_w = grad_output.size(3) + out_h = grad_output.size(2) + + grad_input = grad_rois = None + + if ctx.needs_input_grad[0]: + grad_input = rois.new_zeros(batch_size, num_channels, data_height, + data_width) + ext_module.roi_align_rotated_backward( + grad_output.contiguous(), + rois, + grad_input, + pooled_height=out_h, + pooled_width=out_w, + spatial_scale=spatial_scale, + sample_num=sample_num, + aligned=aligned, + clockwise=clockwise) + return grad_input, grad_rois, None, None, None, None, None + + +roi_align_rotated = RoIAlignRotatedFunction.apply + + +class RoIAlignRotated(nn.Module): + """RoI align pooling layer for rotated proposals. + + It accepts a feature map of shape (N, C, H, W) and rois with shape + (n, 6) with each roi decoded as (batch_index, center_x, center_y, + w, h, angle). The angle is in radian. + + Args: + out_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sample_num (int): number of inputs samples to take for each + output sample. 0 to take samples densely for current models. + aligned (bool): if False, use the legacy implementation in + MMDetection. If True, align the results more perfectly. + Default: True. + clockwise (bool): If True, the angle in each proposal follows a + clockwise fashion in image space, otherwise, the angle is + counterclockwise. Default: False. + + Note: + The implementation of RoIAlign when aligned=True is modified from + https://github.com/facebookresearch/detectron2/ + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; + + The difference does not make a difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + def __init__(self, + out_size, + spatial_scale, + sample_num=0, + aligned=True, + clockwise=False): + super(RoIAlignRotated, self).__init__() + + self.out_size = out_size + self.spatial_scale = float(spatial_scale) + self.sample_num = int(sample_num) + self.aligned = aligned + self.clockwise = clockwise + + def forward(self, features, rois): + return RoIAlignRotatedFunction.apply(features, rois, self.out_size, + self.spatial_scale, + self.sample_num, self.aligned, + self.clockwise) diff --git a/annotator/uniformer/mmcv/ops/roi_pool.py b/annotator/uniformer/mmcv/ops/roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roi_pool.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['roi_pool_forward', 'roi_pool_backward']) + + +class RoIPoolFunction(Function): + + @staticmethod + def symbolic(g, input, rois, output_size, spatial_scale): + return g.op( + 'MaxRoiPool', + input, + rois, + pooled_shape_i=output_size, + spatial_scale_f=spatial_scale) + + @staticmethod + def forward(ctx, input, rois, output_size, spatial_scale=1.0): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + + assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!' + + output_shape = (rois.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + argmax = input.new_zeros(output_shape, dtype=torch.int) + + ext_module.roi_pool_forward( + input, + rois, + output, + argmax, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + ctx.save_for_backward(rois, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax = ctx.saved_tensors + grad_input = grad_output.new_zeros(ctx.input_shape) + + ext_module.roi_pool_backward( + grad_output, + rois, + argmax, + grad_input, + pooled_height=ctx.output_size[0], + pooled_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale) + + return grad_input, None, None, None + + +roi_pool = RoIPoolFunction.apply + + +class RoIPool(nn.Module): + + def __init__(self, output_size, spatial_scale=1.0): + super(RoIPool, self).__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale})' + return s diff --git a/annotator/uniformer/mmcv/ops/roiaware_pool3d.py b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py new file mode 100644 index 0000000000000000000000000000000000000000..291b0e5a9b692492c7d7e495ea639c46042e2f18 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn as nn +from torch.autograd import Function + +import annotator.uniformer.mmcv as mmcv +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward']) + + +class RoIAwarePool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `PartA2 `_ for more + details. + + Args: + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int, optional): The maximum number of points per + voxel. Default: 128. + mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'. + Default: 'max'. + """ + + def __init__(self, out_size, max_pts_per_voxel=128, mode='max'): + super().__init__() + + self.out_size = out_size + self.max_pts_per_voxel = max_pts_per_voxel + assert mode in ['max', 'avg'] + pool_mapping = {'max': 0, 'avg': 1} + self.mode = pool_mapping[mode] + + def forward(self, rois, pts, pts_feature): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] + """ + + return RoIAwarePool3dFunction.apply(rois, pts, pts_feature, + self.out_size, + self.max_pts_per_voxel, self.mode) + + +class RoIAwarePool3dFunction(Function): + + @staticmethod + def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, + mode): + """ + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int): The maximum number of points per voxel. + Default: 128. + mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average + pool). + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output + pooled features. + """ + + if isinstance(out_size, int): + out_x = out_y = out_z = out_size + else: + assert len(out_size) == 3 + assert mmcv.is_tuple_of(out_size, int) + out_x, out_y, out_z = out_size + + num_rois = rois.shape[0] + num_channels = pts_feature.shape[-1] + num_pts = pts.shape[0] + + pooled_features = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels)) + argmax = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int) + pts_idx_of_voxels = pts_feature.new_zeros( + (num_rois, out_x, out_y, out_z, max_pts_per_voxel), + dtype=torch.int) + + ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax, + pts_idx_of_voxels, pooled_features, + mode) + + ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode, + num_pts, num_channels) + return pooled_features + + @staticmethod + def backward(ctx, grad_out): + ret = ctx.roiaware_pool3d_for_backward + pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret + + grad_in = grad_out.new_zeros((num_pts, num_channels)) + ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax, + grad_out.contiguous(), grad_in, + mode) + + return None, None, grad_in, None, None, None diff --git a/annotator/uniformer/mmcv/ops/roipoint_pool3d.py b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py new file mode 100644 index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc --- /dev/null +++ b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py @@ -0,0 +1,77 @@ +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward']) + + +class RoIPointPool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. + + Please refer to `Paper of PartA2 `_ + for more details. + + Args: + num_sampled_points (int, optional): Number of samples in each roi. + Default: 512. + """ + + def __init__(self, num_sampled_points=512): + super().__init__() + self.num_sampled_points = num_sampled_points + + def forward(self, points, point_features, boxes3d): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + return RoIPointPool3dFunction.apply(points, point_features, boxes3d, + self.num_sampled_points) + + +class RoIPointPool3dFunction(Function): + + @staticmethod + def forward(ctx, points, point_features, boxes3d, num_sampled_points=512): + """ + Args: + points (torch.Tensor): Input points whose shape is (B, N, C). + point_features (torch.Tensor): Features of input points whose shape + is (B, N, C). + boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7). + num_sampled_points (int, optional): The num of sampled points. + Default: 512. + + Returns: + pooled_features (torch.Tensor): The output pooled features whose + shape is (B, M, 512, 3 + C). + pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M). + """ + assert len(points.shape) == 3 and points.shape[2] == 3 + batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[ + 1], point_features.shape[2] + pooled_boxes3d = boxes3d.view(batch_size, -1, 7) + pooled_features = point_features.new_zeros( + (batch_size, boxes_num, num_sampled_points, 3 + feature_len)) + pooled_empty_flag = point_features.new_zeros( + (batch_size, boxes_num)).int() + + ext_module.roipoint_pool3d_forward(points.contiguous(), + pooled_boxes3d.contiguous(), + point_features.contiguous(), + pooled_features, pooled_empty_flag) + + return pooled_features, pooled_empty_flag + + @staticmethod + def backward(ctx, grad_out): + raise NotImplementedError diff --git a/annotator/uniformer/mmcv/ops/saconv.py b/annotator/uniformer/mmcv/ops/saconv.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ee3978e097fca422805db4e31ae481006d7971 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/saconv.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init +from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version + + +@CONV_LAYERS.register_module(name='SAC') +class SAConv2d(ConvAWS2d): + """SAC (Switchable Atrous Convolution) + + This is an implementation of SAC in DetectoRS + (https://arxiv.org/pdf/2006.02334.pdf). + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. + Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + use_deform: If ``True``, replace convolution with deformable + convolution. Default: ``False``. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + use_deform=False): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + self.use_deform = use_deform + self.switch = nn.Conv2d( + self.in_channels, 1, kernel_size=1, stride=stride, bias=True) + self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size())) + self.pre_context = nn.Conv2d( + self.in_channels, self.in_channels, kernel_size=1, bias=True) + self.post_context = nn.Conv2d( + self.out_channels, self.out_channels, kernel_size=1, bias=True) + if self.use_deform: + self.offset_s = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.offset_l = nn.Conv2d( + self.in_channels, + 18, + kernel_size=3, + padding=1, + stride=stride, + bias=True) + self.init_weights() + + def init_weights(self): + constant_init(self.switch, 0, bias=1) + self.weight_diff.data.zero_() + constant_init(self.pre_context, 0) + constant_init(self.post_context, 0) + if self.use_deform: + constant_init(self.offset_s, 0) + constant_init(self.offset_l, 0) + + def forward(self, x): + # pre-context + avg_x = F.adaptive_avg_pool2d(x, output_size=1) + avg_x = self.pre_context(avg_x) + avg_x = avg_x.expand_as(x) + x = x + avg_x + # switch + avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect') + avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0) + switch = self.switch(avg_x) + # sac + weight = self._get_weight(self.weight) + zero_bias = torch.zeros( + self.out_channels, device=weight.device, dtype=weight.dtype) + + if self.use_deform: + offset = self.offset_s(avg_x) + out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.5.0')): + out_s = super().conv2d_forward(x, weight) + elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + # bias is a required argument of _conv_forward in torch 1.8.0 + out_s = super()._conv_forward(x, weight, zero_bias) + else: + out_s = super()._conv_forward(x, weight) + ori_p = self.padding + ori_d = self.dilation + self.padding = tuple(3 * p for p in self.padding) + self.dilation = tuple(3 * d for d in self.dilation) + weight = weight + self.weight_diff + if self.use_deform: + offset = self.offset_l(avg_x) + out_l = deform_conv2d(x, offset, weight, self.stride, self.padding, + self.dilation, self.groups, 1) + else: + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.5.0')): + out_l = super().conv2d_forward(x, weight) + elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + # bias is a required argument of _conv_forward in torch 1.8.0 + out_l = super()._conv_forward(x, weight, zero_bias) + else: + out_l = super()._conv_forward(x, weight) + + out = switch * out_s + (1 - switch) * out_l + self.padding = ori_p + self.dilation = ori_d + # post-context + avg_x = F.adaptive_avg_pool2d(out, output_size=1) + avg_x = self.post_context(avg_x) + avg_x = avg_x.expand_as(out) + out = out + avg_x + return out diff --git a/annotator/uniformer/mmcv/ops/scatter_points.py b/annotator/uniformer/mmcv/ops/scatter_points.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/scatter_points.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', + ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward']) + + +class _DynamicScatter(Function): + + @staticmethod + def forward(ctx, feats, coors, reduce_type='max'): + """convert kitti points(N, >=3) to voxels. + + Args: + feats (torch.Tensor): [N, C]. Points features to be reduced + into voxels. + coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates + (specifically multi-dim voxel index) of each points. + reduce_type (str, optional): Reduce op. support 'max', 'sum' and + 'mean'. Default: 'max'. + + Returns: + voxel_feats (torch.Tensor): [M, C]. Reduced features, input + features that shares the same voxel coordinates are reduced to + one row. + voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates. + """ + results = ext_module.dynamic_point_to_voxel_forward( + feats, coors, reduce_type) + (voxel_feats, voxel_coors, point2voxel_map, + voxel_points_count) = results + ctx.reduce_type = reduce_type + ctx.save_for_backward(feats, voxel_feats, point2voxel_map, + voxel_points_count) + ctx.mark_non_differentiable(voxel_coors) + return voxel_feats, voxel_coors + + @staticmethod + def backward(ctx, grad_voxel_feats, grad_voxel_coors=None): + (feats, voxel_feats, point2voxel_map, + voxel_points_count) = ctx.saved_tensors + grad_feats = torch.zeros_like(feats) + # TODO: whether to use index put or use cuda_backward + # To use index put, need point to voxel index + ext_module.dynamic_point_to_voxel_backward( + grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats, + point2voxel_map, voxel_points_count, ctx.reduce_type) + return grad_feats, None, None + + +dynamic_scatter = _DynamicScatter.apply + + +class DynamicScatter(nn.Module): + """Scatters points into voxels, used in the voxel encoder with dynamic + voxelization. + + Note: + The CPU and GPU implementation get the same output, but have numerical + difference after summation and division (e.g., 5e-7). + + Args: + voxel_size (list): list [x, y, z] size of three dimension. + point_cloud_range (list): The coordinate range of points, [x_min, + y_min, z_min, x_max, y_max, z_max]. + average_points (bool): whether to use avg pooling to scatter points + into voxel. + """ + + def __init__(self, voxel_size, point_cloud_range, average_points: bool): + super().__init__() + + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.average_points = average_points + + def forward_single(self, points, coors): + """Scatters points into voxels. + + Args: + points (torch.Tensor): Points to be reduced into voxels. + coors (torch.Tensor): Corresponding voxel coordinates (specifically + multi-dim voxel index) of each points. + + Returns: + voxel_feats (torch.Tensor): Reduced features, input features that + shares the same voxel coordinates are reduced to one row. + voxel_coors (torch.Tensor): Voxel coordinates. + """ + reduce = 'mean' if self.average_points else 'max' + return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce) + + def forward(self, points, coors): + """Scatters points/features into voxels. + + Args: + points (torch.Tensor): Points to be reduced into voxels. + coors (torch.Tensor): Corresponding voxel coordinates (specifically + multi-dim voxel index) of each points. + + Returns: + voxel_feats (torch.Tensor): Reduced features, input features that + shares the same voxel coordinates are reduced to one row. + voxel_coors (torch.Tensor): Voxel coordinates. + """ + if coors.size(-1) == 3: + return self.forward_single(points, coors) + else: + batch_size = coors[-1, 0] + 1 + voxels, voxel_coors = [], [] + for i in range(batch_size): + inds = torch.where(coors[:, 0] == i) + voxel, voxel_coor = self.forward_single( + points[inds], coors[inds][:, 1:]) + coor_pad = nn.functional.pad( + voxel_coor, (1, 0), mode='constant', value=i) + voxel_coors.append(coor_pad) + voxels.append(voxel) + features = torch.cat(voxels, dim=0) + feature_coors = torch.cat(voxel_coors, dim=0) + + return features, feature_coors + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += 'voxel_size=' + str(self.voxel_size) + s += ', point_cloud_range=' + str(self.point_cloud_range) + s += ', average_points=' + str(self.average_points) + s += ')' + return s diff --git a/annotator/uniformer/mmcv/ops/sync_bn.py b/annotator/uniformer/mmcv/ops/sync_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b016fcbe860989c56cd1040034bcfa60e146d2 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/sync_bn.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.module import Module +from torch.nn.parameter import Parameter + +from annotator.uniformer.mmcv.cnn import NORM_LAYERS +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output', + 'sync_bn_backward_param', 'sync_bn_backward_data' +]) + + +class SyncBatchNormFunction(Function): + + @staticmethod + def symbolic(g, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + return g.op( + 'mmcv::MMCVSyncBatchNorm', + input, + running_mean, + running_var, + weight, + bias, + momentum_f=momentum, + eps_f=eps, + group_i=group, + group_size_i=group_size, + stats_mode=stats_mode) + + @staticmethod + def forward(self, input, running_mean, running_var, weight, bias, momentum, + eps, group, group_size, stats_mode): + self.momentum = momentum + self.eps = eps + self.group = group + self.group_size = group_size + self.stats_mode = stats_mode + + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor, + torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' + output = torch.zeros_like(input) + input3d = input.flatten(start_dim=2) + output3d = output.view_as(input3d) + num_channels = input3d.size(1) + + # ensure mean/var/norm/std are initialized as zeros + # ``torch.empty()`` does not guarantee that + mean = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + var = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + norm = torch.zeros_like( + input3d, dtype=torch.float, device=input3d.device) + std = torch.zeros( + num_channels, dtype=torch.float, device=input3d.device) + + batch_size = input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_forward_mean(input3d, mean) + batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype) + else: + # skip updating mean and leave it as zeros when the input is empty + batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype) + + # synchronize mean and the batch flag + vec = torch.cat([mean, batch_flag]) + if self.stats_mode == 'N': + vec *= batch_size + if self.group_size > 1: + dist.all_reduce(vec, group=self.group) + total_batch = vec[-1].detach() + mean = vec[:num_channels] + + if self.stats_mode == 'default': + mean = mean / self.group_size + elif self.stats_mode == 'N': + mean = mean / total_batch.clamp(min=1) + else: + raise NotImplementedError + + # leave var as zeros when the input is empty + if batch_size > 0: + ext_module.sync_bn_forward_var(input3d, mean, var) + + if self.stats_mode == 'N': + var *= batch_size + if self.group_size > 1: + dist.all_reduce(var, group=self.group) + + if self.stats_mode == 'default': + var /= self.group_size + elif self.stats_mode == 'N': + var /= total_batch.clamp(min=1) + else: + raise NotImplementedError + + # if the total batch size over all the ranks is zero, + # we should not update the statistics in the current batch + update_flag = total_batch.clamp(max=1) + momentum = update_flag * self.momentum + ext_module.sync_bn_forward_output( + input3d, + mean, + var, + weight, + bias, + running_mean, + running_var, + norm, + std, + output3d, + eps=self.eps, + momentum=momentum, + group_size=self.group_size) + self.save_for_backward(norm, std, weight) + return output + + @staticmethod + @once_differentiable + def backward(self, grad_output): + norm, std, weight = self.saved_tensors + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(weight) + grad_input = torch.zeros_like(grad_output) + grad_output3d = grad_output.flatten(start_dim=2) + grad_input3d = grad_input.view_as(grad_output3d) + + batch_size = grad_input3d.size(0) + if batch_size > 0: + ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight, + grad_bias) + + # all reduce + if self.group_size > 1: + dist.all_reduce(grad_weight, group=self.group) + dist.all_reduce(grad_bias, group=self.group) + grad_weight /= self.group_size + grad_bias /= self.group_size + + if batch_size > 0: + ext_module.sync_bn_backward_data(grad_output3d, weight, + grad_weight, grad_bias, norm, std, + grad_input3d) + + return grad_input, None, None, grad_weight, grad_bias, \ + None, None, None, None, None + + +@NORM_LAYERS.register_module(name='MMSyncBN') +class SyncBatchNorm(Module): + """Synchronized Batch Normalization. + + Args: + num_features (int): number of features/chennels in input tensor + eps (float, optional): a value added to the denominator for numerical + stability. Defaults to 1e-5. + momentum (float, optional): the value used for the running_mean and + running_var computation. Defaults to 0.1. + affine (bool, optional): whether to use learnable affine parameters. + Defaults to True. + track_running_stats (bool, optional): whether to track the running + mean and variance during training. When set to False, this + module does not track such statistics, and initializes statistics + buffers ``running_mean`` and ``running_var`` as ``None``. When + these buffers are ``None``, this module always uses batch + statistics in both training and eval modes. Defaults to True. + group (int, optional): synchronization of stats happen within + each process group individually. By default it is synchronization + across the whole world. Defaults to None. + stats_mode (str, optional): The statistical mode. Available options + includes ``'default'`` and ``'N'``. Defaults to 'default'. + When ``stats_mode=='default'``, it computes the overall statistics + using those from each worker with equal weight, i.e., the + statistics are synchronized and simply divied by ``group``. This + mode will produce inaccurate statistics when empty tensors occur. + When ``stats_mode=='N'``, it compute the overall statistics using + the total number of batches in each worker ignoring the number of + group, i.e., the statistics are synchronized and then divied by + the total batch ``N``. This mode is beneficial when empty tensors + occur during training, as it average the total mean by the real + number of batch. + """ + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + group=None, + stats_mode='default'): + super(SyncBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + group = dist.group.WORLD if group is None else group + self.group = group + self.group_size = dist.get_world_size(group) + assert stats_mode in ['default', 'N'], \ + f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"' + self.stats_mode = stats_mode + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long)) + else: + self.register_buffer('running_mean', None) + self.register_buffer('running_var', None) + self.register_buffer('num_batches_tracked', None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + self.weight.data.uniform_() # pytorch use ones_() + self.bias.data.zero_() + + def forward(self, input): + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input, got {input.dim()}D input') + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training or not self.track_running_stats: + return SyncBatchNormFunction.apply( + input, self.running_mean, self.running_var, self.weight, + self.bias, exponential_average_factor, self.eps, self.group, + self.group_size, self.stats_mode) + else: + return F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, False, + exponential_average_factor, self.eps) + + def __repr__(self): + s = self.__class__.__name__ + s += f'({self.num_features}, ' + s += f'eps={self.eps}, ' + s += f'momentum={self.momentum}, ' + s += f'affine={self.affine}, ' + s += f'track_running_stats={self.track_running_stats}, ' + s += f'group_size={self.group_size},' + s += f'stats_mode={self.stats_mode})' + return s diff --git a/annotator/uniformer/mmcv/ops/three_interpolate.py b/annotator/uniformer/mmcv/ops/three_interpolate.py new file mode 100644 index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a --- /dev/null +++ b/annotator/uniformer/mmcv/ops/three_interpolate.py @@ -0,0 +1,68 @@ +from typing import Tuple + +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['three_interpolate_forward', 'three_interpolate_backward']) + + +class ThreeInterpolate(Function): + """Performs weighted linear interpolation on 3 features. + + Please refer to `Paper of PointNet++ `_ + for more details. + """ + + @staticmethod + def forward(ctx, features: torch.Tensor, indices: torch.Tensor, + weight: torch.Tensor) -> torch.Tensor: + """ + Args: + features (Tensor): (B, C, M) Features descriptors to be + interpolated + indices (Tensor): (B, n, 3) index three nearest neighbors + of the target features in features + weight (Tensor): (B, n, 3) weights of interpolation + + Returns: + Tensor: (B, C, N) tensor of the interpolated features + """ + assert features.is_contiguous() + assert indices.is_contiguous() + assert weight.is_contiguous() + + B, c, m = features.size() + n = indices.size(1) + ctx.three_interpolate_for_backward = (indices, weight, m) + output = torch.cuda.FloatTensor(B, c, n) + + ext_module.three_interpolate_forward( + features, indices, weight, output, b=B, c=c, m=m, n=n) + return output + + @staticmethod + def backward( + ctx, grad_out: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + grad_out (Tensor): (B, C, N) tensor with gradients of outputs + + Returns: + Tensor: (B, C, M) tensor with gradients of features + """ + idx, weight, m = ctx.three_interpolate_for_backward + B, c, n = grad_out.size() + + grad_features = torch.cuda.FloatTensor(B, c, m).zero_() + grad_out_data = grad_out.data.contiguous() + + ext_module.three_interpolate_backward( + grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m) + return grad_features, None, None + + +three_interpolate = ThreeInterpolate.apply diff --git a/annotator/uniformer/mmcv/ops/three_nn.py b/annotator/uniformer/mmcv/ops/three_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/three_nn.py @@ -0,0 +1,51 @@ +from typing import Tuple + +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['three_nn_forward']) + + +class ThreeNN(Function): + """Find the top-3 nearest neighbors of the target set from the source set. + + Please refer to `Paper of PointNet++ `_ + for more details. + """ + + @staticmethod + def forward(ctx, target: torch.Tensor, + source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + target (Tensor): shape (B, N, 3), points set that needs to + find the nearest neighbors. + source (Tensor): shape (B, M, 3), points set that is used + to find the nearest neighbors of points in target set. + + Returns: + Tensor: shape (B, N, 3), L2 distance of each point in target + set to their corresponding nearest neighbors. + """ + target = target.contiguous() + source = source.contiguous() + + B, N, _ = target.size() + m = source.size(1) + dist2 = torch.cuda.FloatTensor(B, N, 3) + idx = torch.cuda.IntTensor(B, N, 3) + + ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m) + if torch.__version__ != 'parrots': + ctx.mark_non_differentiable(idx) + + return torch.sqrt(dist2), idx + + @staticmethod + def backward(ctx, a=None, b=None): + return None, None + + +three_nn = ThreeNN.apply diff --git a/annotator/uniformer/mmcv/ops/tin_shift.py b/annotator/uniformer/mmcv/ops/tin_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e --- /dev/null +++ b/annotator/uniformer/mmcv/ops/tin_shift.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Code reference from "Temporal Interlacing Network" +# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py +# Hao Shao, Shengju Qian, Yu Liu +# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk + +import torch +import torch.nn as nn +from torch.autograd import Function + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', + ['tin_shift_forward', 'tin_shift_backward']) + + +class TINShiftFunction(Function): + + @staticmethod + def forward(ctx, input, shift): + C = input.size(2) + num_segments = shift.size(1) + if C // num_segments <= 0 or C % num_segments != 0: + raise ValueError('C should be a multiple of num_segments, ' + f'but got C={C} and num_segments={num_segments}.') + + ctx.save_for_backward(shift) + + out = torch.zeros_like(input) + ext_module.tin_shift_forward(input, shift, out) + + return out + + @staticmethod + def backward(ctx, grad_output): + + shift = ctx.saved_tensors[0] + data_grad_input = grad_output.new(*grad_output.size()).zero_() + shift_grad_input = shift.new(*shift.size()).zero_() + ext_module.tin_shift_backward(grad_output, shift, data_grad_input) + + return data_grad_input, shift_grad_input + + +tin_shift = TINShiftFunction.apply + + +class TINShift(nn.Module): + """Temporal Interlace Shift. + + Temporal Interlace shift is a differentiable temporal-wise frame shifting + which is proposed in "Temporal Interlacing Network" + + Please refer to https://arxiv.org/abs/2001.06499 for more details. + Code is modified from https://github.com/mit-han-lab/temporal-shift-module + """ + + def forward(self, input, shift): + """Perform temporal interlace shift. + + Args: + input (Tensor): Feature map with shape [N, num_segments, C, H * W]. + shift (Tensor): Shift tensor with shape [N, num_segments]. + + Returns: + Feature map after temporal interlace shift. + """ + return tin_shift(input, shift) diff --git a/annotator/uniformer/mmcv/ops/upfirdn2d.py b/annotator/uniformer/mmcv/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..c8bb2c3c949eed38a6465ed369fa881538dca010 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/upfirdn2d.py @@ -0,0 +1,330 @@ +# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator +# Augmentation (ADA) +# ======================================================================= + +# 1. Definitions + +# "Licensor" means any person or entity that distributes its Work. + +# "Software" means the original work of authorship made available under +# this License. + +# "Work" means the Software and any additions to or derivative works of +# the Software that are made available under this License. + +# The terms "reproduce," "reproduction," "derivative works," and +# "distribution" have the meaning as provided under U.S. copyright law; +# provided, however, that for the purposes of this License, derivative +# works shall not include works that remain separable from, or merely +# link (or bind by name) to the interfaces of, the Work. + +# Works, including the Software, are "made available" under this License +# by including in or with the Work either (a) a copyright notice +# referencing the applicability of this License to the Work, or (b) a +# copy of this License. + +# 2. License Grants + +# 2.1 Copyright Grant. Subject to the terms and conditions of this +# License, each Licensor grants to you a perpetual, worldwide, +# non-exclusive, royalty-free, copyright license to reproduce, +# prepare derivative works of, publicly display, publicly perform, +# sublicense and distribute its Work and any resulting derivative +# works in any form. + +# 3. Limitations + +# 3.1 Redistribution. You may reproduce or distribute the Work only +# if (a) you do so under this License, (b) you include a complete +# copy of this License with your distribution, and (c) you retain +# without modification any copyright, patent, trademark, or +# attribution notices that are present in the Work. + +# 3.2 Derivative Works. You may specify that additional or different +# terms apply to the use, reproduction, and distribution of your +# derivative works of the Work ("Your Terms") only if (a) Your Terms +# provide that the use limitation in Section 3.3 applies to your +# derivative works, and (b) you identify the specific derivative +# works that are subject to Your Terms. Notwithstanding Your Terms, +# this License (including the redistribution requirements in Section +# 3.1) will continue to apply to the Work itself. + +# 3.3 Use Limitation. The Work and any derivative works thereof only +# may be used or intended for use non-commercially. Notwithstanding +# the foregoing, NVIDIA and its affiliates may use the Work and any +# derivative works commercially. As used herein, "non-commercially" +# means for research or evaluation purposes only. + +# 3.4 Patent Claims. If you bring or threaten to bring a patent claim +# against any Licensor (including any claim, cross-claim or +# counterclaim in a lawsuit) to enforce any patents that you allege +# are infringed by any Work, then your rights under this License from +# such Licensor (including the grant in Section 2.1) will terminate +# immediately. + +# 3.5 Trademarks. This License does not grant any rights to use any +# Licensor’s or its affiliates’ names, logos, or trademarks, except +# as necessary to reproduce the notices described in this License. + +# 3.6 Termination. If you violate any term of this License, then your +# rights under this License (including the grant in Section 2.1) will +# terminate immediately. + +# 4. Disclaimer of Warranty. + +# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +# THIS LICENSE. + +# 5. Limitation of Liability. + +# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGES. + +# ======================================================================= + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +from annotator.uniformer.mmcv.utils import to_2tuple +from ..utils import ext_loader + +upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, + in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + up_x=down_x, + up_y=down_y, + down_x=up_x, + down_y=up_y, + pad_x0=g_pad_x0, + pad_x1=g_pad_x1, + pad_y0=g_pad_y0, + pad_y1=g_pad_y1) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], + in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], + ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + up_x=ctx.up_x, + up_y=ctx.up_y, + down_x=ctx.down_x, + down_y=ctx.down_y, + pad_x0=ctx.pad_x0, + pad_x1=ctx.pad_x1, + pad_y0=ctx.pad_y0, + pad_y1=ctx.pad_y1) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], + ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d( + input, + kernel, + up_x=up_x, + up_y=up_y, + down_x=down_x, + down_y=down_y, + pad_x0=pad_x0, + pad_x1=pad_x1, + pad_y0=pad_y0, + pad_y1=pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + """UpFRIDn for 2d features. + + UpFIRDn is short for upsample, apply FIR filter and downsample. More + details can be found in: + https://www.mathworks.com/help/signal/ref/upfirdn.html + + Args: + input (Tensor): Tensor with shape of (n, c, h, w). + kernel (Tensor): Filter kernel. + up (int | tuple[int], optional): Upsampling factor. If given a number, + we will use this factor for the both height and width side. + Defaults to 1. + down (int | tuple[int], optional): Downsampling factor. If given a + number, we will use this factor for the both height and width side. + Defaults to 1. + pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or + (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0). + + Returns: + Tensor: Tensor after UpFIRDn. + """ + if input.device.type == 'cpu': + if len(pad) == 2: + pad = (pad[0], pad[1], pad[0], pad[1]) + + up = to_2tuple(up) + + down = to_2tuple(down) + + out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1], + pad[0], pad[1], pad[2], pad[3]) + else: + _up = to_2tuple(up) + + _down = to_2tuple(down) + + if len(pad) == 4: + _pad = pad + elif len(pad) == 2: + _pad = (pad[0], pad[1], pad[0], pad[1]) + + out = UpFirDn2d.apply(input, kernel, _up, _down, _pad) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, + [0, 0, + max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, + max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/annotator/uniformer/mmcv/ops/voxelize.py b/annotator/uniformer/mmcv/ops/voxelize.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121 --- /dev/null +++ b/annotator/uniformer/mmcv/ops/voxelize.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.autograd import Function +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward']) + + +class _Voxelization(Function): + + @staticmethod + def forward(ctx, + points, + voxel_size, + coors_range, + max_points=35, + max_voxels=20000): + """Convert kitti points(N, >=3) to voxels. + + Args: + points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points + and points[:, 3:] contain other information like reflectivity. + voxel_size (tuple or float): The size of voxel with the shape of + [3]. + coors_range (tuple or float): The coordinate range of voxel with + the shape of [6]. + max_points (int, optional): maximum points contained in a voxel. if + max_points=-1, it means using dynamic_voxelize. Default: 35. + max_voxels (int, optional): maximum voxels this function create. + for second, 20000 is a good choice. Users should shuffle points + before call this function because max_voxels may drop points. + Default: 20000. + + Returns: + voxels_out (torch.Tensor): Output voxels with the shape of [M, + max_points, ndim]. Only contain points and returned when + max_points != -1. + coors_out (torch.Tensor): Output coordinates with the shape of + [M, 3]. + num_points_per_voxel_out (torch.Tensor): Num points per voxel with + the shape of [M]. Only returned when max_points != -1. + """ + if max_points == -1 or max_voxels == -1: + coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int) + ext_module.dynamic_voxelize_forward(points, coors, voxel_size, + coors_range, 3) + return coors + else: + voxels = points.new_zeros( + size=(max_voxels, max_points, points.size(1))) + coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int) + num_points_per_voxel = points.new_zeros( + size=(max_voxels, ), dtype=torch.int) + voxel_num = ext_module.hard_voxelize_forward( + points, voxels, coors, num_points_per_voxel, voxel_size, + coors_range, max_points, max_voxels, 3) + # select the valid voxels + voxels_out = voxels[:voxel_num] + coors_out = coors[:voxel_num] + num_points_per_voxel_out = num_points_per_voxel[:voxel_num] + return voxels_out, coors_out, num_points_per_voxel_out + + +voxelization = _Voxelization.apply + + +class Voxelization(nn.Module): + """Convert kitti points(N, >=3) to voxels. + + Please refer to `PVCNN `_ for more + details. + + Args: + voxel_size (tuple or float): The size of voxel with the shape of [3]. + point_cloud_range (tuple or float): The coordinate range of voxel with + the shape of [6]. + max_num_points (int): maximum points contained in a voxel. if + max_points=-1, it means using dynamic_voxelize. + max_voxels (int, optional): maximum voxels this function create. + for second, 20000 is a good choice. Users should shuffle points + before call this function because max_voxels may drop points. + Default: 20000. + """ + + def __init__(self, + voxel_size, + point_cloud_range, + max_num_points, + max_voxels=20000): + super().__init__() + + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.max_num_points = max_num_points + if isinstance(max_voxels, tuple): + self.max_voxels = max_voxels + else: + self.max_voxels = _pair(max_voxels) + + point_cloud_range = torch.tensor( + point_cloud_range, dtype=torch.float32) + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + grid_size = (point_cloud_range[3:] - + point_cloud_range[:3]) / voxel_size + grid_size = torch.round(grid_size).long() + input_feat_shape = grid_size[:2] + self.grid_size = grid_size + # the origin shape is as [x-len, y-len, z-len] + # [w, h, d] -> [d, h, w] + self.pcd_shape = [*input_feat_shape, 1][::-1] + + def forward(self, input): + if self.training: + max_voxels = self.max_voxels[0] + else: + max_voxels = self.max_voxels[1] + + return voxelization(input, self.voxel_size, self.point_cloud_range, + self.max_num_points, max_voxels) + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += 'voxel_size=' + str(self.voxel_size) + s += ', point_cloud_range=' + str(self.point_cloud_range) + s += ', max_num_points=' + str(self.max_num_points) + s += ', max_voxels=' + str(self.max_voxels) + s += ')' + return s diff --git a/annotator/uniformer/mmcv/parallel/__init__.py b/annotator/uniformer/mmcv/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .collate import collate +from .data_container import DataContainer +from .data_parallel import MMDataParallel +from .distributed import MMDistributedDataParallel +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter, scatter_kwargs +from .utils import is_module_wrapper + +__all__ = [ + 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel', + 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS' +] diff --git a/annotator/uniformer/mmcv/parallel/_functions.py b/annotator/uniformer/mmcv/parallel/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/_functions.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import _get_stream + + +def scatter(input, devices, streams=None): + """Scatters tensor across multiple GPUs.""" + if streams is None: + streams = [None] * len(devices) + + if isinstance(input, list): + chunk_size = (len(input) - 1) // len(devices) + 1 + outputs = [ + scatter(input[i], [devices[i // chunk_size]], + [streams[i // chunk_size]]) for i in range(len(input)) + ] + return outputs + elif isinstance(input, torch.Tensor): + output = input.contiguous() + # TODO: copy to a pinned buffer first (if copying from CPU) + stream = streams[0] if output.numel() > 0 else None + if devices != [-1]: + with torch.cuda.device(devices[0]), torch.cuda.stream(stream): + output = output.cuda(devices[0], non_blocking=True) + else: + # unsqueeze the first dimension thus the tensor's shape is the + # same as those scattered with GPU. + output = output.unsqueeze(0) + return output + else: + raise Exception(f'Unknown type {type(input)}.') + + +def synchronize_stream(output, devices, streams): + if isinstance(output, list): + chunk_size = len(output) // len(devices) + for i in range(len(devices)): + for j in range(chunk_size): + synchronize_stream(output[i * chunk_size + j], [devices[i]], + [streams[i]]) + elif isinstance(output, torch.Tensor): + if output.numel() != 0: + with torch.cuda.device(devices[0]): + main_stream = torch.cuda.current_stream() + main_stream.wait_stream(streams[0]) + output.record_stream(main_stream) + else: + raise Exception(f'Unknown type {type(output)}.') + + +def get_input_device(input): + if isinstance(input, list): + for item in input: + input_device = get_input_device(item) + if input_device != -1: + return input_device + return -1 + elif isinstance(input, torch.Tensor): + return input.get_device() if input.is_cuda else -1 + else: + raise Exception(f'Unknown type {type(input)}.') + + +class Scatter: + + @staticmethod + def forward(target_gpus, input): + input_device = get_input_device(input) + streams = None + if input_device == -1 and target_gpus != [-1]: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(device) for device in target_gpus] + + outputs = scatter(input, target_gpus, streams) + # Synchronize with the copy stream + if streams is not None: + synchronize_stream(outputs, target_gpus, streams) + + return tuple(outputs) diff --git a/annotator/uniformer/mmcv/parallel/collate.py b/annotator/uniformer/mmcv/parallel/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/collate.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Mapping, Sequence + +import torch +import torch.nn.functional as F +from torch.utils.data.dataloader import default_collate + +from .data_container import DataContainer + + +def collate(batch, samples_per_gpu=1): + """Puts each data field into a tensor/DataContainer with outer dimension + batch size. + + Extend default_collate to add support for + :type:`~mmcv.parallel.DataContainer`. There are 3 cases. + + 1. cpu_only = True, e.g., meta data + 2. cpu_only = False, stack = True, e.g., images tensors + 3. cpu_only = False, stack = False, e.g., gt bboxes + """ + + if not isinstance(batch, Sequence): + raise TypeError(f'{batch.dtype} is not supported.') + + if isinstance(batch[0], DataContainer): + stacked = [] + if batch[0].cpu_only: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer( + stacked, batch[0].stack, batch[0].padding_value, cpu_only=True) + elif batch[0].stack: + for i in range(0, len(batch), samples_per_gpu): + assert isinstance(batch[i].data, torch.Tensor) + + if batch[i].pad_dims is not None: + ndim = batch[i].dim() + assert ndim > batch[i].pad_dims + max_shape = [0 for _ in range(batch[i].pad_dims)] + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = batch[i].size(-dim) + for sample in batch[i:i + samples_per_gpu]: + for dim in range(0, ndim - batch[i].pad_dims): + assert batch[i].size(dim) == sample.size(dim) + for dim in range(1, batch[i].pad_dims + 1): + max_shape[dim - 1] = max(max_shape[dim - 1], + sample.size(-dim)) + padded_samples = [] + for sample in batch[i:i + samples_per_gpu]: + pad = [0 for _ in range(batch[i].pad_dims * 2)] + for dim in range(1, batch[i].pad_dims + 1): + pad[2 * dim - + 1] = max_shape[dim - 1] - sample.size(-dim) + padded_samples.append( + F.pad( + sample.data, pad, value=sample.padding_value)) + stacked.append(default_collate(padded_samples)) + elif batch[i].pad_dims is None: + stacked.append( + default_collate([ + sample.data + for sample in batch[i:i + samples_per_gpu] + ])) + else: + raise ValueError( + 'pad_dims should be either None or integers (1-3)') + + else: + for i in range(0, len(batch), samples_per_gpu): + stacked.append( + [sample.data for sample in batch[i:i + samples_per_gpu]]) + return DataContainer(stacked, batch[0].stack, batch[0].padding_value) + elif isinstance(batch[0], Sequence): + transposed = zip(*batch) + return [collate(samples, samples_per_gpu) for samples in transposed] + elif isinstance(batch[0], Mapping): + return { + key: collate([d[key] for d in batch], samples_per_gpu) + for key in batch[0] + } + else: + return default_collate(batch) diff --git a/annotator/uniformer/mmcv/parallel/data_container.py b/annotator/uniformer/mmcv/parallel/data_container.py new file mode 100644 index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/data_container.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch + + +def assert_tensor_type(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not isinstance(args[0].data, torch.Tensor): + raise AttributeError( + f'{args[0].__class__.__name__} has no attribute ' + f'{func.__name__} for type {args[0].datatype}') + return func(*args, **kwargs) + + return wrapper + + +class DataContainer: + """A container for any type of objects. + + Typically tensors will be stacked in the collate function and sliced along + some dimension in the scatter function. This behavior has some limitations. + 1. All tensors have to be the same size. + 2. Types are limited (numpy array or Tensor). + + We design `DataContainer` and `MMDataParallel` to overcome these + limitations. The behavior can be either of the following. + + - copy to GPU, pad all tensors to the same size and stack them + - copy to GPU without stacking + - leave the objects as is and pass it to the model + - pad_dims specifies the number of last few dimensions to do padding + """ + + def __init__(self, + data, + stack=False, + padding_value=0, + cpu_only=False, + pad_dims=2): + self._data = data + self._cpu_only = cpu_only + self._stack = stack + self._padding_value = padding_value + assert pad_dims in [None, 1, 2, 3] + self._pad_dims = pad_dims + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self.data)})' + + def __len__(self): + return len(self._data) + + @property + def data(self): + return self._data + + @property + def datatype(self): + if isinstance(self.data, torch.Tensor): + return self.data.type() + else: + return type(self.data) + + @property + def cpu_only(self): + return self._cpu_only + + @property + def stack(self): + return self._stack + + @property + def padding_value(self): + return self._padding_value + + @property + def pad_dims(self): + return self._pad_dims + + @assert_tensor_type + def size(self, *args, **kwargs): + return self.data.size(*args, **kwargs) + + @assert_tensor_type + def dim(self): + return self.data.dim() diff --git a/annotator/uniformer/mmcv/parallel/data_parallel.py b/annotator/uniformer/mmcv/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/data_parallel.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import chain + +from torch.nn.parallel import DataParallel + +from .scatter_gather import scatter_kwargs + + +class MMDataParallel(DataParallel): + """The DataParallel module that supports DataContainer. + + MMDataParallel has two main differences with PyTorch DataParallel: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data during both GPU and CPU inference. + - It implement two more APIs ``train_step()`` and ``val_step()``. + + Args: + module (:class:`nn.Module`): Module to be encapsulated. + device_ids (list[int]): Device IDS of modules to be scattered to. + Defaults to None when GPU is not available. + output_device (str | int): Device ID for output. Defaults to None. + dim (int): Dimension used to scatter the data. Defaults to 0. + """ + + def __init__(self, *args, dim=0, **kwargs): + super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs) + self.dim = dim + + def forward(self, *inputs, **kwargs): + """Override the original forward function. + + The main difference lies in the CPU inference where the data in + :class:`DataContainers` will still be gathered. + """ + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module(*inputs[0], **kwargs[0]) + else: + return super().forward(*inputs, **kwargs) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.train_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + 'instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.train_step(*inputs[0], **kwargs[0]) + + def val_step(self, *inputs, **kwargs): + if not self.device_ids: + # We add the following line thus the module could gather and + # convert data containers as those in GPU inference + inputs, kwargs = self.scatter(inputs, kwargs, [-1]) + return self.module.val_step(*inputs[0], **kwargs[0]) + + assert len(self.device_ids) == 1, \ + ('MMDataParallel only supports single GPU training, if you need to' + ' train with multiple GPUs, please use MMDistributedDataParallel' + ' instead.') + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError( + 'module must have its parameters and buffers ' + f'on device {self.src_device_obj} (device_ids[0]) but ' + f'found one of them on device: {t.device}') + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return self.module.val_step(*inputs[0], **kwargs[0]) diff --git a/annotator/uniformer/mmcv/parallel/distributed.py b/annotator/uniformer/mmcv/parallel/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4c27903db58a54d37ea1ed9ec0104098b486f2 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/distributed.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel.distributed import (DistributedDataParallel, + _find_tensors) + +from annotator.uniformer.mmcv import print_log +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .scatter_gather import scatter_kwargs + + +class MMDistributedDataParallel(DistributedDataParallel): + """The DDP module that supports DataContainer. + + MMDDP has two main differences with PyTorch DDP: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data. + - It implement two APIs ``train_step()`` and ``val_step()``. + """ + + def to_kwargs(self, inputs, kwargs, device_id): + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + """train_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.train_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if getattr(self, 'require_forward_param_sync', True): + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.train_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.train_step(*inputs, **kwargs) + + if torch.is_grad_enabled() and getattr( + self, 'require_backward_grad_sync', True): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def val_step(self, *inputs, **kwargs): + """val_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.val_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if getattr(self, 'require_forward_param_sync', True): + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.val_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.val_step(*inputs, **kwargs) + + if torch.is_grad_enabled() and getattr( + self, 'require_backward_grad_sync', True): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output diff --git a/annotator/uniformer/mmcv/parallel/distributed_deprecated.py b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..676937a2085d4da20fa87923041a200fca6214eb --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .registry import MODULE_WRAPPERS +from .scatter_gather import scatter_kwargs + + +@MODULE_WRAPPERS.register_module() +class MMDistributedDataParallel(nn.Module): + + def __init__(self, + module, + dim=0, + broadcast_buffers=True, + bucket_cap_mb=25): + super(MMDistributedDataParallel, self).__init__() + self.module = module + self.dim = dim + self.broadcast_buffers = broadcast_buffers + + self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024 + self._sync_params() + + def _dist_broadcast_coalesced(self, tensors, buffer_size): + for tensors in _take_tensors(tensors, buffer_size): + flat_tensors = _flatten_dense_tensors(tensors) + dist.broadcast(flat_tensors, 0) + for tensor, synced in zip( + tensors, _unflatten_dense_tensors(flat_tensors, tensors)): + tensor.copy_(synced) + + def _sync_params(self): + module_states = list(self.module.state_dict().values()) + if len(module_states) > 0: + self._dist_broadcast_coalesced(module_states, + self.broadcast_bucket_size) + if self.broadcast_buffers: + if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) < digit_version('1.0')): + buffers = [b.data for b in self.module._all_buffers()] + else: + buffers = [b.data for b in self.module.buffers()] + if len(buffers) > 0: + self._dist_broadcast_coalesced(buffers, + self.broadcast_bucket_size) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + return self.module(*inputs[0], **kwargs[0]) + + def train_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.train_step(*inputs[0], **kwargs[0]) + return output + + def val_step(self, *inputs, **kwargs): + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.val_step(*inputs[0], **kwargs[0]) + return output diff --git a/annotator/uniformer/mmcv/parallel/registry.py b/annotator/uniformer/mmcv/parallel/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..a204a07fba10e614223f090d1a57cf9c4d74d4a1 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/registry.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from annotator.uniformer.mmcv.utils import Registry + +MODULE_WRAPPERS = Registry('module wrapper') +MODULE_WRAPPERS.register_module(module=DataParallel) +MODULE_WRAPPERS.register_module(module=DistributedDataParallel) diff --git a/annotator/uniformer/mmcv/parallel/scatter_gather.py b/annotator/uniformer/mmcv/parallel/scatter_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47 --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/scatter_gather.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parallel._functions import Scatter as OrigScatter + +from ._functions import Scatter +from .data_container import DataContainer + + +def scatter(inputs, target_gpus, dim=0): + """Scatter inputs to target gpus. + + The only difference from original :func:`scatter` is to add support for + :type:`~mmcv.parallel.DataContainer`. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + if target_gpus != [-1]: + return OrigScatter.apply(target_gpus, None, dim, obj) + else: + # for CPU inference we use self-implemented scatter + return Scatter.forward(target_gpus, obj) + if isinstance(obj, DataContainer): + if obj.cpu_only: + return obj.data + else: + return Scatter.forward(target_gpus, obj.data) + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + out = list(map(list, zip(*map(scatter_map, obj)))) + return out + if isinstance(obj, dict) and len(obj) > 0: + out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return out + return [obj for targets in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): + """Scatter with support for kwargs dictionary.""" + inputs = scatter(inputs, target_gpus, dim) if inputs else [] + kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/annotator/uniformer/mmcv/parallel/utils.py b/annotator/uniformer/mmcv/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b --- /dev/null +++ b/annotator/uniformer/mmcv/parallel/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import MODULE_WRAPPERS + + +def is_module_wrapper(module): + """Check if a module is a module wrapper. + + The following 3 modules in MMCV (and their subclasses) are regarded as + module wrappers: DataParallel, DistributedDataParallel, + MMDistributedDataParallel (the deprecated version). You may add you own + module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: True if the input module is a module wrapper. + """ + module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values()) + return isinstance(module, module_wrappers) diff --git a/annotator/uniformer/mmcv/runner/__init__.py b/annotator/uniformer/mmcv/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_module import BaseModule, ModuleList, Sequential +from .base_runner import BaseRunner +from .builder import RUNNERS, build_runner +from .checkpoint import (CheckpointLoader, _load_checkpoint, + _load_checkpoint_with_prefix, load_checkpoint, + load_state_dict, save_checkpoint, weights_to_cpu) +from .default_constructor import DefaultRunnerConstructor +from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info, + init_dist, master_only) +from .epoch_based_runner import EpochBasedRunner, Runner +from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model +from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, + DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook, + Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, Hook, IterTimerHook, + LoggerHook, LrUpdaterHook, MlflowLoggerHook, + NeptuneLoggerHook, OptimizerHook, PaviLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) +from .iter_based_runner import IterBasedRunner, IterLoader +from .log_buffer import LogBuffer +from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, + DefaultOptimizerConstructor, build_optimizer, + build_optimizer_constructor) +from .priority import Priority, get_priority +from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed + +__all__ = [ + 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer', + 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', + 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', + 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', + 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook', + 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict', + 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority', + 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict', + 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS', + 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer', + 'build_optimizer_constructor', 'IterLoader', 'set_random_seed', + 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', + 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', + 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', + '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', + 'ModuleList', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor' +] diff --git a/annotator/uniformer/mmcv/runner/base_module.py b/annotator/uniformer/mmcv/runner/base_module.py new file mode 100644 index 0000000000000000000000000000000000000000..617fad9bb89f10a9a0911d962dfb3bc8f3a3628c --- /dev/null +++ b/annotator/uniformer/mmcv/runner/base_module.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from abc import ABCMeta +from collections import defaultdict +from logging import FileHandler + +import torch.nn as nn + +from annotator.uniformer.mmcv.runner.dist_utils import master_only +from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log + + +class BaseModule(nn.Module, metaclass=ABCMeta): + """Base module for all modules in openmmlab. + + ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional + functionality of parameter initialization. Compared with + ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes. + + - ``init_cfg``: the config to control the initialization. + - ``init_weights``: The function of parameter + initialization and recording initialization + information. + - ``_params_init_info``: Used to track the parameter + initialization information. This attribute only + exists during executing the ``init_weights``. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, init_cfg=None): + """Initialize BaseModule, inherited from `torch.nn.Module`""" + + # NOTE init_cfg can be defined in different levels, but init_cfg + # in low levels has a higher priority. + + super(BaseModule, self).__init__() + # define default value of init_cfg instead of hard code + # in init_weights() function + self._is_init = False + + self.init_cfg = copy.deepcopy(init_cfg) + + # Backward compatibility in derived classes + # if pretrained is not None: + # warnings.warn('DeprecationWarning: pretrained is a deprecated \ + # key, please consider using init_cfg') + # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + @property + def is_init(self): + return self._is_init + + def init_weights(self): + """Initialize the weights.""" + + is_top_level_module = False + # check if it is top-level module + if not hasattr(self, '_params_init_info'): + # The `_params_init_info` is used to record the initialization + # information of the parameters + # the key should be the obj:`nn.Parameter` of model and the value + # should be a dict containing + # - init_info (str): The string that describes the initialization. + # - tmp_mean_value (FloatTensor): The mean of the parameter, + # which indicates whether the parameter has been modified. + # this attribute would be deleted after all parameters + # is initialized. + self._params_init_info = defaultdict(dict) + is_top_level_module = True + + # Initialize the `_params_init_info`, + # When detecting the `tmp_mean_value` of + # the corresponding parameter is changed, update related + # initialization information + for name, param in self.named_parameters(): + self._params_init_info[param][ + 'init_info'] = f'The value is the same before and ' \ + f'after calling `init_weights` ' \ + f'of {self.__class__.__name__} ' + self._params_init_info[param][ + 'tmp_mean_value'] = param.data.mean() + + # pass `params_init_info` to all submodules + # All submodules share the same `params_init_info`, + # so it will be updated when parameters are + # modified at any level of the model. + for sub_module in self.modules(): + sub_module._params_init_info = self._params_init_info + + # Get the initialized logger, if not exist, + # create a logger named `mmcv` + logger_names = list(logger_initialized.keys()) + logger_name = logger_names[0] if logger_names else 'mmcv' + + from ..cnn import initialize + from ..cnn.utils.weight_init import update_init_info + module_name = self.__class__.__name__ + if not self._is_init: + if self.init_cfg: + print_log( + f'initialize {module_name} with init_cfg {self.init_cfg}', + logger=logger_name) + initialize(self, self.init_cfg) + if isinstance(self.init_cfg, dict): + # prevent the parameters of + # the pre-trained model + # from being overwritten by + # the `init_weights` + if self.init_cfg['type'] == 'Pretrained': + return + + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights() + # users may overload the `init_weights` + update_init_info( + m, + init_info=f'Initialized by ' + f'user-defined `init_weights`' + f' in {m.__class__.__name__} ') + + self._is_init = True + else: + warnings.warn(f'init_weights of {self.__class__.__name__} has ' + f'been called more than once.') + + if is_top_level_module: + self._dump_init_info(logger_name) + + for sub_module in self.modules(): + del sub_module._params_init_info + + @master_only + def _dump_init_info(self, logger_name): + """Dump the initialization information to a file named + `initialization.log.json` in workdir. + + Args: + logger_name (str): The name of logger. + """ + + logger = get_logger(logger_name) + + with_file_handler = False + # dump the information to the logger file if there is a `FileHandler` + for handler in logger.handlers: + if isinstance(handler, FileHandler): + handler.stream.write( + 'Name of parameter - Initialization information\n') + for name, param in self.named_parameters(): + handler.stream.write( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n") + handler.stream.flush() + with_file_handler = True + if not with_file_handler: + for name, param in self.named_parameters(): + print_log( + f'\n{name} - {param.shape}: ' + f"\n{self._params_init_info[param]['init_info']} \n ", + logger=logger_name) + + def __repr__(self): + s = super().__repr__() + if self.init_cfg: + s += f'\ninit_cfg={self.init_cfg}' + return s + + +class Sequential(BaseModule, nn.Sequential): + """Sequential module in openmmlab. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, *args, init_cfg=None): + BaseModule.__init__(self, init_cfg) + nn.Sequential.__init__(self, *args) + + +class ModuleList(BaseModule, nn.ModuleList): + """ModuleList in openmmlab. + + Args: + modules (iterable, optional): an iterable of modules to add. + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, modules=None, init_cfg=None): + BaseModule.__init__(self, init_cfg) + nn.ModuleList.__init__(self, modules) diff --git a/annotator/uniformer/mmcv/runner/base_runner.py b/annotator/uniformer/mmcv/runner/base_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4928db0a73b56fe0218a4bf66ec4ffa082d31ccc --- /dev/null +++ b/annotator/uniformer/mmcv/runner/base_runner.py @@ -0,0 +1,542 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +import os.path as osp +import warnings +from abc import ABCMeta, abstractmethod + +import torch +from torch.optim import Optimizer + +import annotator.uniformer.mmcv as mmcv +from ..parallel import is_module_wrapper +from .checkpoint import load_checkpoint +from .dist_utils import get_dist_info +from .hooks import HOOKS, Hook +from .log_buffer import LogBuffer +from .priority import Priority, get_priority +from .utils import get_time_str + + +class BaseRunner(metaclass=ABCMeta): + """The base class of Runner, a training helper for PyTorch. + + All subclasses should implement the following APIs: + + - ``run()`` + - ``train()`` + - ``val()`` + - ``save_checkpoint()`` + + Args: + model (:obj:`torch.nn.Module`): The model to be run. + batch_processor (callable): A callable method that process a data + batch. The interface of this method should be + `batch_processor(model, data, train_mode) -> dict` + optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an + optimizer (in most cases) or a dict of optimizers (in models that + requires more than one optimizer, e.g., GAN). + work_dir (str, optional): The working directory to save checkpoints + and logs. Defaults to None. + logger (:obj:`logging.Logger`): Logger used during training. + Defaults to None. (The default value is just for backward + compatibility) + meta (dict | None): A dict records some import information such as + environment info and seed, which will be logged in logger hook. + Defaults to None. + max_epochs (int, optional): Total training epochs. + max_iters (int, optional): Total training iterations. + """ + + def __init__(self, + model, + batch_processor=None, + optimizer=None, + work_dir=None, + logger=None, + meta=None, + max_iters=None, + max_epochs=None): + if batch_processor is not None: + if not callable(batch_processor): + raise TypeError('batch_processor must be callable, ' + f'but got {type(batch_processor)}') + warnings.warn('batch_processor is deprecated, please implement ' + 'train_step() and val_step() in the model instead.') + # raise an error is `batch_processor` is not None and + # `model.train_step()` exists. + if is_module_wrapper(model): + _model = model.module + else: + _model = model + if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'): + raise RuntimeError( + 'batch_processor and model.train_step()/model.val_step() ' + 'cannot be both available.') + else: + assert hasattr(model, 'train_step') + + # check the type of `optimizer` + if isinstance(optimizer, dict): + for name, optim in optimizer.items(): + if not isinstance(optim, Optimizer): + raise TypeError( + f'optimizer must be a dict of torch.optim.Optimizers, ' + f'but optimizer["{name}"] is a {type(optim)}') + elif not isinstance(optimizer, Optimizer) and optimizer is not None: + raise TypeError( + f'optimizer must be a torch.optim.Optimizer object ' + f'or dict or None, but got {type(optimizer)}') + + # check the type of `logger` + if not isinstance(logger, logging.Logger): + raise TypeError(f'logger must be a logging.Logger object, ' + f'but got {type(logger)}') + + # check the type of `meta` + if meta is not None and not isinstance(meta, dict): + raise TypeError( + f'meta must be a dict or None, but got {type(meta)}') + + self.model = model + self.batch_processor = batch_processor + self.optimizer = optimizer + self.logger = logger + self.meta = meta + # create work_dir + if mmcv.is_str(work_dir): + self.work_dir = osp.abspath(work_dir) + mmcv.mkdir_or_exist(self.work_dir) + elif work_dir is None: + self.work_dir = None + else: + raise TypeError('"work_dir" must be a str or None') + + # get model name from the model class + if hasattr(self.model, 'module'): + self._model_name = self.model.module.__class__.__name__ + else: + self._model_name = self.model.__class__.__name__ + + self._rank, self._world_size = get_dist_info() + self.timestamp = get_time_str() + self.mode = None + self._hooks = [] + self._epoch = 0 + self._iter = 0 + self._inner_iter = 0 + + if max_epochs is not None and max_iters is not None: + raise ValueError( + 'Only one of `max_epochs` or `max_iters` can be set.') + + self._max_epochs = max_epochs + self._max_iters = max_iters + # TODO: Redesign LogBuffer, it is not flexible and elegant enough + self.log_buffer = LogBuffer() + + @property + def model_name(self): + """str: Name of the model, usually the module class name.""" + return self._model_name + + @property + def rank(self): + """int: Rank of current process. (distributed training)""" + return self._rank + + @property + def world_size(self): + """int: Number of processes participating in the job. + (distributed training)""" + return self._world_size + + @property + def hooks(self): + """list[:obj:`Hook`]: A list of registered hooks.""" + return self._hooks + + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + + @property + def inner_iter(self): + """int: Iteration in an epoch.""" + return self._inner_iter + + @property + def max_epochs(self): + """int: Maximum training epochs.""" + return self._max_epochs + + @property + def max_iters(self): + """int: Maximum training iterations.""" + return self._max_iters + + @abstractmethod + def train(self): + pass + + @abstractmethod + def val(self): + pass + + @abstractmethod + def run(self, data_loaders, workflow, **kwargs): + pass + + @abstractmethod + def save_checkpoint(self, + out_dir, + filename_tmpl, + save_optimizer=True, + meta=None, + create_symlink=True): + pass + + def current_lr(self): + """Get current learning rates. + + Returns: + list[float] | dict[str, list[float]]: Current learning rates of all + param groups. If the runner has a dict of optimizers, this + method will return a dict. + """ + if isinstance(self.optimizer, torch.optim.Optimizer): + lr = [group['lr'] for group in self.optimizer.param_groups] + elif isinstance(self.optimizer, dict): + lr = dict() + for name, optim in self.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + + def current_momentum(self): + """Get current momentums. + + Returns: + list[float] | dict[str, list[float]]: Current momentums of all + param groups. If the runner has a dict of optimizers, this + method will return a dict. + """ + + def _get_momentum(optimizer): + momentums = [] + for group in optimizer.param_groups: + if 'momentum' in group.keys(): + momentums.append(group['momentum']) + elif 'betas' in group.keys(): + momentums.append(group['betas'][0]) + else: + momentums.append(0) + return momentums + + if self.optimizer is None: + raise RuntimeError( + 'momentum is not applicable because optimizer does not exist.') + elif isinstance(self.optimizer, torch.optim.Optimizer): + momentums = _get_momentum(self.optimizer) + elif isinstance(self.optimizer, dict): + momentums = dict() + for name, optim in self.optimizer.items(): + momentums[name] = _get_momentum(optim) + return momentums + + def register_hook(self, hook, priority='NORMAL'): + """Register a hook into the hook list. + + The hook will be inserted into a priority queue, with the specified + priority (See :class:`Priority` for details of priorities). + For hooks with the same priority, they will be triggered in the same + order as they are registered. + + Args: + hook (:obj:`Hook`): The hook to be registered. + priority (int or str or :obj:`Priority`): Hook priority. + Lower value means higher priority. + """ + assert isinstance(hook, Hook) + if hasattr(hook, 'priority'): + raise ValueError('"priority" is a reserved attribute for hooks') + priority = get_priority(priority) + hook.priority = priority + # insert the hook to a sorted list + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + if priority >= self._hooks[i].priority: + self._hooks.insert(i + 1, hook) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook) + + def register_hook_from_cfg(self, hook_cfg): + """Register a hook from its cfg. + + Args: + hook_cfg (dict): Hook config. It should have at least keys 'type' + and 'priority' indicating its type and priority. + + Notes: + The specific hook class to register should not use 'type' and + 'priority' arguments during initialization. + """ + hook_cfg = hook_cfg.copy() + priority = hook_cfg.pop('priority', 'NORMAL') + hook = mmcv.build_from_cfg(hook_cfg, HOOKS) + self.register_hook(hook, priority=priority) + + def call_hook(self, fn_name): + """Call all hooks. + + Args: + fn_name (str): The function name in each hook to be called, such as + "before_train_epoch". + """ + for hook in self._hooks: + getattr(hook, fn_name)(self) + + def get_hook_info(self): + # Get hooks info in each stage + stage_hook_map = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name + except ValueError: + priority = hook.priority + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + + def load_checkpoint(self, + filename, + map_location='cpu', + strict=False, + revise_keys=[(r'^module.', '')]): + return load_checkpoint( + self.model, + filename, + map_location, + strict, + self.logger, + revise_keys=revise_keys) + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + if map_location == 'default': + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint(checkpoint) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + if self.meta is None: + self.meta = {} + self.meta.setdefault('hook_msgs', {}) + # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages + self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {})) + + # Re-calculate the number of iterations when resuming + # models with different number of GPUs + if 'config' in checkpoint['meta']: + config = mmcv.Config.fromstring( + checkpoint['meta']['config'], file_format='.py') + previous_gpu_ids = config.get('gpu_ids', None) + if previous_gpu_ids and len(previous_gpu_ids) > 0 and len( + previous_gpu_ids) != self.world_size: + self._iter = int(self._iter * len(previous_gpu_ids) / + self.world_size) + self.logger.info('the iteration number is changed due to ' + 'change of GPU number') + + # resume meta information meta + self.meta = checkpoint['meta'] + + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) + + def register_lr_hook(self, lr_config): + if lr_config is None: + return + elif isinstance(lr_config, dict): + assert 'policy' in lr_config + policy_type = lr_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of Lr updater. + # Since this is not applicable for ` + # CosineAnnealingLrUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'LrUpdaterHook' + lr_config['type'] = hook_type + hook = mmcv.build_from_cfg(lr_config, HOOKS) + else: + hook = lr_config + self.register_hook(hook, priority='VERY_HIGH') + + def register_momentum_hook(self, momentum_config): + if momentum_config is None: + return + if isinstance(momentum_config, dict): + assert 'policy' in momentum_config + policy_type = momentum_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of momentum updater. + # Since this is not applicable for + # `CosineAnnealingMomentumUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'MomentumUpdaterHook' + momentum_config['type'] = hook_type + hook = mmcv.build_from_cfg(momentum_config, HOOKS) + else: + hook = momentum_config + self.register_hook(hook, priority='HIGH') + + def register_optimizer_hook(self, optimizer_config): + if optimizer_config is None: + return + if isinstance(optimizer_config, dict): + optimizer_config.setdefault('type', 'OptimizerHook') + hook = mmcv.build_from_cfg(optimizer_config, HOOKS) + else: + hook = optimizer_config + self.register_hook(hook, priority='ABOVE_NORMAL') + + def register_checkpoint_hook(self, checkpoint_config): + if checkpoint_config is None: + return + if isinstance(checkpoint_config, dict): + checkpoint_config.setdefault('type', 'CheckpointHook') + hook = mmcv.build_from_cfg(checkpoint_config, HOOKS) + else: + hook = checkpoint_config + self.register_hook(hook, priority='NORMAL') + + def register_logger_hooks(self, log_config): + if log_config is None: + return + log_interval = log_config['interval'] + for info in log_config['hooks']: + logger_hook = mmcv.build_from_cfg( + info, HOOKS, default_args=dict(interval=log_interval)) + self.register_hook(logger_hook, priority='VERY_LOW') + + def register_timer_hook(self, timer_config): + if timer_config is None: + return + if isinstance(timer_config, dict): + timer_config_ = copy.deepcopy(timer_config) + hook = mmcv.build_from_cfg(timer_config_, HOOKS) + else: + hook = timer_config + self.register_hook(hook, priority='LOW') + + def register_custom_hooks(self, custom_config): + if custom_config is None: + return + + if not isinstance(custom_config, list): + custom_config = [custom_config] + + for item in custom_config: + if isinstance(item, dict): + self.register_hook_from_cfg(item) + else: + self.register_hook(item, priority='NORMAL') + + def register_profiler_hook(self, profiler_config): + if profiler_config is None: + return + if isinstance(profiler_config, dict): + profiler_config.setdefault('type', 'ProfilerHook') + hook = mmcv.build_from_cfg(profiler_config, HOOKS) + else: + hook = profiler_config + self.register_hook(hook) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + timer_config=dict(type='IterTimerHook'), + custom_hooks_config=None): + """Register default and custom hooks for training. + + Default and custom hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + self.register_lr_hook(lr_config) + self.register_momentum_hook(momentum_config) + self.register_optimizer_hook(optimizer_config) + self.register_checkpoint_hook(checkpoint_config) + self.register_timer_hook(timer_config) + self.register_logger_hooks(log_config) + self.register_custom_hooks(custom_hooks_config) diff --git a/annotator/uniformer/mmcv/runner/builder.py b/annotator/uniformer/mmcv/runner/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f --- /dev/null +++ b/annotator/uniformer/mmcv/runner/builder.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from ..utils import Registry + +RUNNERS = Registry('runner') +RUNNER_BUILDERS = Registry('runner builder') + + +def build_runner_constructor(cfg): + return RUNNER_BUILDERS.build(cfg) + + +def build_runner(cfg, default_args=None): + runner_cfg = copy.deepcopy(cfg) + constructor_type = runner_cfg.pop('constructor', + 'DefaultRunnerConstructor') + runner_constructor = build_runner_constructor( + dict( + type=constructor_type, + runner_cfg=runner_cfg, + default_args=default_args)) + runner = runner_constructor() + return runner diff --git a/annotator/uniformer/mmcv/runner/checkpoint.py b/annotator/uniformer/mmcv/runner/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b29ca320679164432f446adad893e33fb2b4b29e --- /dev/null +++ b/annotator/uniformer/mmcv/runner/checkpoint.py @@ -0,0 +1,707 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import re +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo + +import annotator.uniformer.mmcv as mmcv +from ..fileio import FileClient +from ..fileio import load as load_file +from ..parallel import is_module_wrapper +from ..utils import mkdir_or_exist +from .dist_utils import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +class CheckpointLoader: + """A general checkpoint loader to manage all schemes.""" + + _schemes = {} + + @classmethod + def _register_scheme(cls, prefixes, loader, force=False): + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._schemes) or force: + cls._schemes[prefix] = loader + else: + raise KeyError( + f'{prefix} is already registered as a loader backend, ' + 'add "force=True" if you want to override it') + # sort, longer prefixes take priority + cls._schemes = OrderedDict( + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) + + @classmethod + def register_scheme(cls, prefixes, loader=None, force=False): + """Register a loader to CheckpointLoader. + + This method can be used as a normal class method or a decorator. + + Args: + prefixes (str or list[str] or tuple[str]): + The prefix of the registered loader. + loader (function, optional): The loader function to be registered. + When this method is used as a decorator, loader is None. + Defaults to None. + force (bool, optional): Whether to override the loader + if the prefix has already been registered. Defaults to False. + """ + + if loader is not None: + cls._register_scheme(prefixes, loader, force=force) + return + + def _register(loader_cls): + cls._register_scheme(prefixes, loader_cls, force=force) + return loader_cls + + return _register + + @classmethod + def _get_checkpoint_loader(cls, path): + """Finds a loader that supports the given path. Falls back to the local + loader if no other loader is found. + + Args: + path (str): checkpoint path + + Returns: + loader (function): checkpoint loader + """ + + for p in cls._schemes: + if path.startswith(p): + return cls._schemes[p] + + @classmethod + def load_checkpoint(cls, filename, map_location=None, logger=None): + """load checkpoint through URL scheme path. + + Args: + filename (str): checkpoint file name with given prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + logger (:mod:`logging.Logger`, optional): The logger for message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint_loader = cls._get_checkpoint_loader(filename) + class_name = checkpoint_loader.__name__ + mmcv.print_log( + f'load checkpoint from {class_name[10:]} path: {filename}', logger) + return checkpoint_loader(filename, map_location) + + +@CheckpointLoader.register_scheme(prefixes='') +def load_from_local(filename, map_location): + """load checkpoint by local file path. + + Args: + filename (str): local checkpoint file path + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) +def load_from_http(filename, map_location=None, model_dir=None): + """load checkpoint through HTTP or HTTPS scheme path. In distributed + setting, this function only download checkpoint at local rank 0. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + model_dir (string, optional): directory in which to save the object, + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url( + filename, model_dir=model_dir, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url( + filename, model_dir=model_dir, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='pavi://') +def load_from_pavi(filename, map_location=None): + """load checkpoint through the file path prefixed with pavi. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with pavi prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + assert filename.startswith('pavi://'), \ + f'Expected filename startswith `pavi://`, but get {filename}' + model_path = filename[7:] + + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='s3://') +def load_from_ceph(filename, map_location=None, backend='petrel'): + """load checkpoint through the file path prefixed with s3. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with s3 prefix + map_location (str, optional): Same as :func:`torch.load`. + backend (str, optional): The storage backend type. Options are 'ceph', + 'petrel'. Default: 'petrel'. + + .. warning:: + :class:`mmcv.fileio.file_client.CephBackend` will be deprecated, + please use :class:`mmcv.fileio.file_client.PetrelBackend` instead. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + allowed_backends = ['ceph', 'petrel'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + + if backend == 'ceph': + warnings.warn( + 'CephBackend will be deprecated, please use PetrelBackend instead') + + # CephClient and PetrelBackend have the same prefix 's3://' and the latter + # will be chosen as default. If PetrelBackend can not be instantiated + # successfully, the CephClient will be chosen. + try: + file_client = FileClient(backend=backend) + except ImportError: + allowed_backends.remove(backend) + file_client = FileClient(backend=allowed_backends[0]) + + with io.BytesIO(file_client.get(filename)) as buffer: + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) +def load_from_torchvision(filename, map_location=None): + """load checkpoint through the file path prefixed with modelzoo or + torchvision. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + model_urls = get_torchvision_models() + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_name = filename[11:] + else: + model_name = filename[14:] + return load_from_http(model_urls[model_name], map_location=map_location) + + +@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) +def load_from_openmmlab(filename, map_location=None): + """load checkpoint through the file path prefixed with open-mmlab or + openmmlab. + + Args: + filename (str): checkpoint file path with open-mmlab or + openmmlab prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_external_models() + prefix_str = 'open-mmlab://' + if filename.startswith(prefix_str): + model_name = filename[13:] + else: + model_name = filename[12:] + prefix_str = 'openmmlab://' + + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'{prefix_str}{model_name} is deprecated in favor ' + f'of {prefix_str}{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_from_http(model_url, map_location=map_location) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='mmcls://') +def load_from_mmcls(filename, map_location=None): + """load checkpoint through the file path prefixed with mmcls. + + Args: + filename (str): checkpoint file path with mmcls prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_from_http( + model_urls[model_name], map_location=map_location) + checkpoint = _process_mmcls_checkpoint(checkpoint) + return checkpoint + + +def _load_checkpoint(filename, map_location=None, logger=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str, optional): Same as :func:`torch.load`. + Default: None. + logger (:mod:`logging.Logger`, optional): The logger for error message. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + return CheckpointLoader.load_checkpoint(filename, map_location, logger) + + +def _load_checkpoint_with_prefix(prefix, filename, map_location=None): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = _load_checkpoint(filename, map_location=map_location) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix.endswith('.'): + prefix += '.' + prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + + +def load_checkpoint(model, + filename, + map_location=None, + strict=False, + logger=None, + revise_keys=[(r'^module\.', '')]): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location, logger) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # strip prefix of state_dict + metadata = getattr(state_dict, '_metadata', OrderedDict()) + for p, r in revise_keys: + state_dict = OrderedDict( + {re.sub(p, r, k): v + for k, v in state_dict.items()}) + # Keep metadata in state_dict + state_dict._metadata = metadata + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + # Keep metadata in state_dict + state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict()) + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, + filename, + optimizer=None, + meta=None, + file_client_args=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + if file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" if filename starts with' + f'"pavi://", but got {file_client_args}') + try: + from pavi import modelcloud + from pavi import exception + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except exception.NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + file_client = FileClient.infer_client(file_client_args, filename) + with io.BytesIO() as f: + torch.save(checkpoint, f) + file_client.put(f.getvalue(), filename) diff --git a/annotator/uniformer/mmcv/runner/default_constructor.py b/annotator/uniformer/mmcv/runner/default_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1f5b44168768dfda3947393a63a6cf9cf50b41 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/default_constructor.py @@ -0,0 +1,44 @@ +from .builder import RUNNER_BUILDERS, RUNNERS + + +@RUNNER_BUILDERS.register_module() +class DefaultRunnerConstructor: + """Default constructor for runners. + + Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`. + For example, We can inject some new properties and functions for `Runner`. + + Example: + >>> from annotator.uniformer.mmcv.runner import RUNNER_BUILDERS, build_runner + >>> # Define a new RunnerReconstructor + >>> @RUNNER_BUILDERS.register_module() + >>> class MyRunnerConstructor: + ... def __init__(self, runner_cfg, default_args=None): + ... if not isinstance(runner_cfg, dict): + ... raise TypeError('runner_cfg should be a dict', + ... f'but got {type(runner_cfg)}') + ... self.runner_cfg = runner_cfg + ... self.default_args = default_args + ... + ... def __call__(self): + ... runner = RUNNERS.build(self.runner_cfg, + ... default_args=self.default_args) + ... # Add new properties for existing runner + ... runner.my_name = 'my_runner' + ... runner.my_function = lambda self: print(self.my_name) + ... ... + >>> # build your runner + >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40, + ... constructor='MyRunnerConstructor') + >>> runner = build_runner(runner_cfg) + """ + + def __init__(self, runner_cfg, default_args=None): + if not isinstance(runner_cfg, dict): + raise TypeError('runner_cfg should be a dict', + f'but got {type(runner_cfg)}') + self.runner_cfg = runner_cfg + self.default_args = default_args + + def __call__(self): + return RUNNERS.build(self.runner_cfg, default_args=self.default_args) diff --git a/annotator/uniformer/mmcv/runner/dist_utils.py b/annotator/uniformer/mmcv/runner/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe --- /dev/null +++ b/annotator/uniformer/mmcv/runner/dist_utils.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import os +import subprocess +from collections import OrderedDict + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def allreduce_params(params, coalesce=True, bucket_size_mb=-1): + """Allreduce parameters. + + Args: + params (list[torch.Parameters]): List of parameters or buffers of a + model. + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + _, world_size = get_dist_info() + if world_size == 1: + return + params = [param.data for param in params] + if coalesce: + _allreduce_coalesced(params, world_size, bucket_size_mb) + else: + for tensor in params: + dist.all_reduce(tensor.div_(world_size)) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + """Allreduce gradients. + + Args: + params (list[torch.Parameters]): List of parameters of a model + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + _, world_size = get_dist_info() + if world_size == 1: + return + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) diff --git a/annotator/uniformer/mmcv/runner/epoch_based_runner.py b/annotator/uniformer/mmcv/runner/epoch_based_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..766a9ce6afdf09cd11b1b15005f5132583011348 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/epoch_based_runner.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import platform +import shutil +import time +import warnings + +import torch + +import annotator.uniformer.mmcv as mmcv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .utils import get_host_info + + +@RUNNERS.register_module() +class EpochBasedRunner(BaseRunner): + """Epoch-based Runner. + + This runner train models epoch by epoch. + """ + + def run_iter(self, data_batch, train_mode, **kwargs): + if self.batch_processor is not None: + outputs = self.batch_processor( + self.model, data_batch, train_mode=train_mode, **kwargs) + elif train_mode: + outputs = self.model.train_step(data_batch, self.optimizer, + **kwargs) + else: + outputs = self.model.val_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('"batch_processor()" or "model.train_step()"' + 'and "model.val_step()" must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._max_iters = self._max_epochs * len(self.data_loader) + self.call_hook('before_train_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self._inner_iter = i + self.call_hook('before_train_iter') + self.run_iter(data_batch, train_mode=True, **kwargs) + self.call_hook('after_train_iter') + self._iter += 1 + + self.call_hook('after_train_epoch') + self._epoch += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + self.call_hook('before_val_epoch') + time.sleep(2) # Prevent possible deadlock during epoch transition + for i, data_batch in enumerate(self.data_loader): + self._inner_iter = i + self.call_hook('before_val_iter') + self.run_iter(data_batch, train_mode=False) + self.call_hook('after_val_iter') + + self.call_hook('after_val_epoch') + + def run(self, data_loaders, workflow, max_epochs=None, **kwargs): + """Start running. + + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, epochs) to specify the + running order and epochs. E.g, [('train', 2), ('val', 1)] means + running 2 epochs for training and 1 epoch for validation, + iteratively. + """ + assert isinstance(data_loaders, list) + assert mmcv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_epochs is not None: + warnings.warn( + 'setting max_epochs in run is deprecated, ' + 'please set max_epochs in runner_config', DeprecationWarning) + self._max_epochs = max_epochs + + assert self._max_epochs is not None, ( + 'max_epochs must be specified during instantiation') + + for i, flow in enumerate(workflow): + mode, epochs = flow + if mode == 'train': + self._max_iters = self._max_epochs * len(data_loaders[i]) + break + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d epochs', workflow, + self._max_epochs) + self.call_hook('before_run') + + while self.epoch < self._max_epochs: + for i, flow in enumerate(workflow): + mode, epochs = flow + if isinstance(mode, str): # self.train() + if not hasattr(self, mode): + raise ValueError( + f'runner has no method named "{mode}" to run an ' + 'epoch') + epoch_runner = getattr(self, mode) + else: + raise TypeError( + 'mode in workflow must be a str, but got {}'.format( + type(mode))) + + for _ in range(epochs): + if mode == 'train' and self.epoch >= self._max_epochs: + break + epoch_runner(data_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_run') + + def save_checkpoint(self, + out_dir, + filename_tmpl='epoch_{}.pth', + save_optimizer=True, + meta=None, + create_symlink=True): + """Save the checkpoint. + + Args: + out_dir (str): The directory that checkpoints are saved. + filename_tmpl (str, optional): The checkpoint filename template, + which contains a placeholder for the epoch number. + Defaults to 'epoch_{}.pth'. + save_optimizer (bool, optional): Whether to save the optimizer to + the checkpoint. Defaults to True. + meta (dict, optional): The meta information to be saved in the + checkpoint. Defaults to None. + create_symlink (bool, optional): Whether to create a symlink + "latest.pth" to point to the latest checkpoint. + Defaults to True. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. + # More details in https://github.com/open-mmlab/mmcv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.epoch + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + +@RUNNERS.register_module() +class Runner(EpochBasedRunner): + """Deprecated name of EpochBasedRunner.""" + + def __init__(self, *args, **kwargs): + warnings.warn( + 'Runner was deprecated, please use EpochBasedRunner instead') + super().__init__(*args, **kwargs) diff --git a/annotator/uniformer/mmcv/runner/fp16_utils.py b/annotator/uniformer/mmcv/runner/fp16_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1981011d6859192e3e663e29d13500d56ba47f6c --- /dev/null +++ b/annotator/uniformer/mmcv/runner/fp16_utils.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import warnings +from collections import abc +from inspect import getfullargspec + +import numpy as np +import torch +import torch.nn as nn + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from .dist_utils import allreduce_grads as _allreduce_grads + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 + # manually, so the behavior may not be consistent with real amp. + from torch.cuda.amp import autocast +except ImportError: + pass + + +def cast_tensor_type(inputs, src_type, dst_type): + """Recursively convert Tensor in inputs from src_type to dst_type. + + Args: + inputs: Inputs that to be casted. + src_type (torch.dtype): Source type.. + dst_type (torch.dtype): Destination type. + + Returns: + The same type with inputs, but all contained Tensors have been cast. + """ + if isinstance(inputs, nn.Module): + return inputs + elif isinstance(inputs, torch.Tensor): + return inputs.to(dst_type) + elif isinstance(inputs, str): + return inputs + elif isinstance(inputs, np.ndarray): + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)({ + k: cast_tensor_type(v, src_type, dst_type) + for k, v in inputs.items() + }) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( + cast_tensor_type(item, src_type, dst_type) for item in inputs) + else: + return inputs + + +def auto_fp16(apply_to=None, out_fp32=False): + """Decorator to enable fp16 training automatically. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If inputs arguments are fp32 tensors, they will + be converted to fp16 automatically. Arguments other than fp32 tensors are + ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp32 (bool): Whether to convert the output back to fp32. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp16 + >>> @auto_fp16() + >>> def forward(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp16 + >>> @auto_fp16(apply_to=('pred', )) + >>> def do_something(self, pred, others): + >>> pass + """ + + def auto_fp16_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@auto_fp16 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + # NOTE: default args are not taken into consideration + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.float, torch.half)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = {} + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.float, torch.half) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=True): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp32: + output = cast_tensor_type(output, torch.half, torch.float) + return output + + return new_func + + return auto_fp16_wrapper + + +def force_fp32(apply_to=None, out_fp16=False): + """Decorator to convert input arguments to fp32 in force. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If there are some inputs that must be processed + in fp32 mode, then this decorator can handle it. If inputs arguments are + fp16 tensors, they will be converted to fp32 automatically. Arguments other + than fp16 tensors are ignored. If you are using PyTorch >= 1.6, + torch.cuda.amp is used as the backend, otherwise, original mmcv + implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp16 (bool): Whether to convert the output back to fp16. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp32 + >>> @force_fp32() + >>> def loss(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp32 + >>> @force_fp32(apply_to=('pred', )) + >>> def post_process(self, pred, others): + >>> pass + """ + + def force_fp32_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@force_fp32 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.half, torch.float)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.half, torch.float) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if (TORCH_VERSION != 'parrots' and + digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + with autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + warnings.warning( + '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be ' + 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads') + _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb) + + +def wrap_fp16_model(model): + """Wrap the FP32 model to FP16. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + For PyTorch >= 1.6, this function will + 1. Set fp16 flag inside the model to True. + + Otherwise: + 1. Convert FP32 model to FP16. + 2. Remain some necessary layers to be FP32, e.g., normalization layers. + 3. Set `fp16_enabled` flag inside the model to True. + + Args: + model (nn.Module): Model in FP32. + """ + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.6.0')): + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) + # set `fp16_enabled` flag + for m in model.modules(): + if hasattr(m, 'fp16_enabled'): + m.fp16_enabled = True + + +def patch_norm_fp32(module): + """Recursively convert normalization layers from FP16 to FP32. + + Args: + module (nn.Module): The modules to be converted in FP16. + + Returns: + nn.Module: The converted module, the normalization layers have been + converted to FP32. + """ + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3': + module.forward = patch_forward_method(module.forward, torch.half, + torch.float) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method(func, src_type, dst_type, convert_output=True): + """Patch the forward method of a module. + + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. + """ + + def new_forward(*args, **kwargs): + output = func(*cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type)) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward + + +class LossScaler: + """Class that manages loss scaling in mixed precision training which + supports both dynamic or static mode. + + The implementation refers to + https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py. + Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling. + It's important to understand how :class:`LossScaler` operates. + Loss scaling is designed to combat the problem of underflowing + gradients encountered at long times when training fp16 networks. + Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. + If overflowing gradients are encountered, :class:`FP16_Optimizer` then + skips the update step for this particular iteration/minibatch, + and :class:`LossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients + detected,:class:`LossScaler` increases the loss scale once more. + In this way :class:`LossScaler` attempts to "ride the edge" of always + using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float): Initial loss scale value, default: 2**32. + scale_factor (float): Factor used when adjusting the loss scale. + Default: 2. + mode (str): Loss scaling mode. 'dynamic' or 'static' + scale_window (int): Number of consecutive iterations without an + overflow to wait before increasing the loss scale. Default: 1000. + """ + + def __init__(self, + init_scale=2**32, + mode='dynamic', + scale_factor=2., + scale_window=1000): + self.cur_scale = init_scale + self.cur_iter = 0 + assert mode in ('dynamic', + 'static'), 'mode can only be dynamic or static' + self.mode = mode + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + def has_overflow(self, params): + """Check if params contain overflow.""" + if self.mode != 'dynamic': + return False + for p in params: + if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data): + return True + return False + + def _has_inf_or_nan(x): + """Check if params contain NaN.""" + try: + cpu_sum = float(x.float().sum()) + except RuntimeError as instance: + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') \ + or cpu_sum != cpu_sum: + return True + return False + + def update_scale(self, overflow): + """update the current loss scale value when overflow happens.""" + if self.mode != 'dynamic': + return + if overflow: + self.cur_scale = max(self.cur_scale / self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % \ + self.scale_window == 0: + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + def state_dict(self): + """Returns the state of the scaler as a :class:`dict`.""" + return dict( + cur_scale=self.cur_scale, + cur_iter=self.cur_iter, + mode=self.mode, + last_overflow_iter=self.last_overflow_iter, + scale_factor=self.scale_factor, + scale_window=self.scale_window) + + def load_state_dict(self, state_dict): + """Loads the loss_scaler state dict. + + Args: + state_dict (dict): scaler state. + """ + self.cur_scale = state_dict['cur_scale'] + self.cur_iter = state_dict['cur_iter'] + self.mode = state_dict['mode'] + self.last_overflow_iter = state_dict['last_overflow_iter'] + self.scale_factor = state_dict['scale_factor'] + self.scale_window = state_dict['scale_window'] + + @property + def loss_scale(self): + return self.cur_scale diff --git a/annotator/uniformer/mmcv/runner/hooks/__init__.py b/annotator/uniformer/mmcv/runner/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .checkpoint import CheckpointHook +from .closure import ClosureHook +from .ema import EMAHook +from .evaluation import DistEvalHook, EvalHook +from .hook import HOOKS, Hook +from .iter_timer import IterTimerHook +from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook, + NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook, + TextLoggerHook, WandbLoggerHook) +from .lr_updater import LrUpdaterHook +from .memory import EmptyCacheHook +from .momentum_updater import MomentumUpdaterHook +from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, OptimizerHook) +from .profiler import ProfilerHook +from .sampler_seed import DistSamplerSeedHook +from .sync_buffer import SyncBuffersHook + +__all__ = [ + 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', + 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', + 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', + 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', + 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook', + 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook', + 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook', + 'GradientCumulativeFp16OptimizerHook' +] diff --git a/annotator/uniformer/mmcv/runner/hooks/checkpoint.py b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6af3fae43ac4b35532641a81eb13557edfc7dfba --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +from annotator.uniformer.mmcv.fileio import FileClient +from ..dist_utils import allreduce_params, master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class CheckpointHook(Hook): + """Save checkpoints periodically. + + Args: + interval (int): The saving period. If ``by_epoch=True``, interval + indicates epochs, otherwise it indicates iterations. + Default: -1, which means "never". + by_epoch (bool): Saving checkpoints by epoch or by iteration. + Default: True. + save_optimizer (bool): Whether to save optimizer state_dict in the + checkpoint. It is usually used for resuming experiments. + Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, ``runner.work_dir`` will be used by default. If + specified, the ``out_dir`` will be the concatenation of ``out_dir`` + and the last level directory of ``runner.work_dir``. + `Changed in version 1.3.16.` + max_keep_ckpts (int, optional): The maximum checkpoints to keep. + In some cases we want only the latest few checkpoints and would + like to delete old ones to save the disk space. + Default: -1, which means unlimited. + save_last (bool, optional): Whether to force the last checkpoint to be + saved regardless of interval. Default: True. + sync_buffer (bool, optional): Whether to synchronize buffers in + different gpus. Default: False. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + + .. warning:: + Before v1.3.16, the ``out_dir`` argument indicates the path where the + checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the + root directory and the final path to save checkpoint is the + concatenation of ``out_dir`` and the last level directory of + ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A" + and the value of ``runner.work_dir`` is "/path/of/B", then the final + path will be "/path/of/A/B". + """ + + def __init__(self, + interval=-1, + by_epoch=True, + save_optimizer=True, + out_dir=None, + max_keep_ckpts=-1, + save_last=True, + sync_buffer=False, + file_client_args=None, + **kwargs): + self.interval = interval + self.by_epoch = by_epoch + self.save_optimizer = save_optimizer + self.out_dir = out_dir + self.max_keep_ckpts = max_keep_ckpts + self.save_last = save_last + self.args = kwargs + self.sync_buffer = sync_buffer + self.file_client_args = file_client_args + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + + runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by ' + f'{self.file_client.name}.')) + + # disable the create_symlink option because some file backends do not + # allow to create a symlink + if 'create_symlink' in self.args: + if self.args[ + 'create_symlink'] and not self.file_client.allow_symlink: + self.args['create_symlink'] = False + warnings.warn( + ('create_symlink is set as True by the user but is changed' + 'to be False because creating symbolic link is not ' + f'allowed in {self.file_client.name}')) + else: + self.args['create_symlink'] = self.file_client.allow_symlink + + def after_train_epoch(self, runner): + if not self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` epochs + # 2. reach the last epoch of training + if self.every_n_epochs( + runner, self.interval) or (self.save_last + and self.is_last_epoch(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.epoch + 1} epochs') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) + + @master_only + def _save_checkpoint(self, runner): + """Save the current checkpoint and delete unwanted checkpoint.""" + runner.save_checkpoint( + self.out_dir, save_optimizer=self.save_optimizer, **self.args) + if runner.meta is not None: + if self.by_epoch: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) + else: + cur_ckpt_filename = self.args.get( + 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) + runner.meta.setdefault('hook_msgs', dict()) + runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( + self.out_dir, cur_ckpt_filename) + # remove other checkpoints + if self.max_keep_ckpts > 0: + if self.by_epoch: + name = 'epoch_{}.pth' + current_ckpt = runner.epoch + 1 + else: + name = 'iter_{}.pth' + current_ckpt = runner.iter + 1 + redundant_ckpts = range( + current_ckpt - self.max_keep_ckpts * self.interval, 0, + -self.interval) + filename_tmpl = self.args.get('filename_tmpl', name) + for _step in redundant_ckpts: + ckpt_path = self.file_client.join_path( + self.out_dir, filename_tmpl.format(_step)) + if self.file_client.isfile(ckpt_path): + self.file_client.remove(ckpt_path) + else: + break + + def after_train_iter(self, runner): + if self.by_epoch: + return + + # save checkpoint for following cases: + # 1. every ``self.interval`` iterations + # 2. reach the last iteration of training + if self.every_n_iters( + runner, self.interval) or (self.save_last + and self.is_last_iter(runner)): + runner.logger.info( + f'Saving checkpoint at {runner.iter + 1} iterations') + if self.sync_buffer: + allreduce_params(runner.model.buffers()) + self._save_checkpoint(runner) diff --git a/annotator/uniformer/mmcv/runner/hooks/closure.py b/annotator/uniformer/mmcv/runner/hooks/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/closure.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ClosureHook(Hook): + + def __init__(self, fn_name, fn): + assert hasattr(self, fn_name) + assert callable(fn) + setattr(self, fn_name, fn) diff --git a/annotator/uniformer/mmcv/runner/hooks/ema.py b/annotator/uniformer/mmcv/runner/hooks/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/ema.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...parallel import is_module_wrapper +from ..hooks.hook import HOOKS, Hook + + +@HOOKS.register_module() +class EMAHook(Hook): + r"""Exponential Moving Average Hook. + + Use Exponential Moving Average on all parameters of model in training + process. All parameters have a ema backup, which update by the formula + as below. EMAHook takes priority over EvalHook and CheckpointSaverHook. + + .. math:: + + \text{Xema\_{t+1}} = (1 - \text{momentum}) \times + \text{Xema\_{t}} + \text{momentum} \times X_t + + Args: + momentum (float): The momentum used for updating ema parameter. + Defaults to 0.0002. + interval (int): Update ema parameter every interval iteration. + Defaults to 1. + warm_up (int): During first warm_up steps, we may use smaller momentum + to update ema parameters more slowly. Defaults to 100. + resume_from (str): The checkpoint path. Defaults to None. + """ + + def __init__(self, + momentum=0.0002, + interval=1, + warm_up=100, + resume_from=None): + assert isinstance(interval, int) and interval > 0 + self.warm_up = warm_up + self.interval = interval + assert momentum > 0 and momentum < 1 + self.momentum = momentum**interval + self.checkpoint = resume_from + + def before_run(self, runner): + """To resume model with it's ema parameters more friendly. + + Register ema parameter as ``named_buffer`` to model + """ + model = runner.model + if is_module_wrapper(model): + model = model.module + self.param_ema_buffer = {} + self.model_parameters = dict(model.named_parameters(recurse=True)) + for name, value in self.model_parameters.items(): + # "." is not allowed in module's buffer name + buffer_name = f"ema_{name.replace('.', '_')}" + self.param_ema_buffer[name] = buffer_name + model.register_buffer(buffer_name, value.data.clone()) + self.model_buffers = dict(model.named_buffers(recurse=True)) + if self.checkpoint is not None: + runner.resume(self.checkpoint) + + def after_train_iter(self, runner): + """Update ema parameter every self.interval iterations.""" + curr_step = runner.iter + # We warm up the momentum considering the instability at beginning + momentum = min(self.momentum, + (1 + curr_step) / (self.warm_up + curr_step)) + if curr_step % self.interval != 0: + return + for name, parameter in self.model_parameters.items(): + buffer_name = self.param_ema_buffer[name] + buffer_parameter = self.model_buffers[buffer_name] + buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data) + + def after_train_epoch(self, runner): + """We load parameter values from ema backup to model before the + EvalHook.""" + self._swap_ema_parameters() + + def before_train_epoch(self, runner): + """We recover model's parameter from ema backup after last epoch's + EvalHook.""" + self._swap_ema_parameters() + + def _swap_ema_parameters(self): + """Swap the parameter of model with parameter in ema_buffer.""" + for name, value in self.model_parameters.items(): + temp = value.data.clone() + ema_buffer = self.model_buffers[self.param_ema_buffer[name]] + value.data.copy_(ema_buffer.data) + ema_buffer.data.copy_(temp) diff --git a/annotator/uniformer/mmcv/runner/hooks/evaluation.py b/annotator/uniformer/mmcv/runner/hooks/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..4d00999ce5665c53bded8de9e084943eee2d230d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/evaluation.py @@ -0,0 +1,509 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from math import inf + +import torch.distributed as dist +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.data import DataLoader + +from annotator.uniformer.mmcv.fileio import FileClient +from annotator.uniformer.mmcv.utils import is_seq_of +from .hook import Hook +from .logger import LoggerHook + + +class EvalHook(Hook): + """Non-Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in non-distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader, and return the test results. If ``None``, the default + test function ``mmcv.engine.single_gpu_test`` will be used. + (default: ``None``) + greater_keys (List[str] | None, optional): Metric keys that will be + inferred by 'greater' comparison rule. If ``None``, + _default_greater_keys will be used. (default: ``None``) + less_keys (List[str] | None, optional): Metric keys that will be + inferred by 'less' comparison rule. If ``None``, _default_less_keys + will be used. (default: ``None``) + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + `New in version 1.3.16.` + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. + `New in version 1.3.16.` + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + + Notes: + If new arguments are added for EvalHook, tools/test.py, + tools/eval_metric.py may be affected. + """ + + # Since the key for determine greater or less is related to the downstream + # tasks, downstream repos may need to overwrite the following inner + # variable accordingly. + + rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y} + init_value_map = {'greater': -inf, 'less': inf} + _default_greater_keys = [ + 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU', + 'mAcc', 'aAcc' + ] + _default_less_keys = ['loss'] + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + out_dir=None, + file_client_args=None, + **eval_kwargs): + if not isinstance(dataloader, DataLoader): + raise TypeError(f'dataloader must be a pytorch DataLoader, ' + f'but got {type(dataloader)}') + + if interval <= 0: + raise ValueError(f'interval must be a positive number, ' + f'but got {interval}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean' + + if start is not None and start < 0: + raise ValueError(f'The evaluation start epoch {start} is smaller ' + f'than 0') + + self.dataloader = dataloader + self.interval = interval + self.start = start + self.by_epoch = by_epoch + + assert isinstance(save_best, str) or save_best is None, \ + '""save_best"" should be a str or None ' \ + f'rather than {type(save_best)}' + self.save_best = save_best + self.eval_kwargs = eval_kwargs + self.initial_flag = True + + if test_fn is None: + from annotator.uniformer.mmcv.engine import single_gpu_test + self.test_fn = single_gpu_test + else: + self.test_fn = test_fn + + if greater_keys is None: + self.greater_keys = self._default_greater_keys + else: + if not isinstance(greater_keys, (list, tuple)): + greater_keys = (greater_keys, ) + assert is_seq_of(greater_keys, str) + self.greater_keys = greater_keys + + if less_keys is None: + self.less_keys = self._default_less_keys + else: + if not isinstance(less_keys, (list, tuple)): + less_keys = (less_keys, ) + assert is_seq_of(less_keys, str) + self.less_keys = less_keys + + if self.save_best is not None: + self.best_ckpt_path = None + self._init_rule(rule, self.save_best) + + self.out_dir = out_dir + self.file_client_args = file_client_args + + def _init_rule(self, rule, key_indicator): + """Initialize rule, key_indicator, comparison_func, and best score. + + Here is the rule to determine which rule is used for key indicator + when the rule is not specific (note that the key indicator matching + is case-insensitive): + 1. If the key indicator is in ``self.greater_keys``, the rule will be + specified as 'greater'. + 2. Or if the key indicator is in ``self.less_keys``, the rule will be + specified as 'less'. + 3. Or if the key indicator is equal to the substring in any one item + in ``self.greater_keys``, the rule will be specified as 'greater'. + 4. Or if the key indicator is equal to the substring in any one item + in ``self.less_keys``, the rule will be specified as 'less'. + + Args: + rule (str | None): Comparison rule for best score. + key_indicator (str | None): Key indicator to determine the + comparison rule. + """ + if rule not in self.rule_map and rule is not None: + raise KeyError(f'rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None: + if key_indicator != 'auto': + # `_lc` here means we use the lower case of keys for + # case-insensitive matching + key_indicator_lc = key_indicator.lower() + greater_keys = [key.lower() for key in self.greater_keys] + less_keys = [key.lower() for key in self.less_keys] + + if key_indicator_lc in greater_keys: + rule = 'greater' + elif key_indicator_lc in less_keys: + rule = 'less' + elif any(key in key_indicator_lc for key in greater_keys): + rule = 'greater' + elif any(key in key_indicator_lc for key in less_keys): + rule = 'less' + else: + raise ValueError(f'Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + f'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: + self.compare_func = self.rule_map[self.rule] + + def before_run(self, runner): + if not self.out_dir: + self.out_dir = runner.work_dir + + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + + # if `self.out_dir` is not equal to `runner.work_dir`, it means that + # `self.out_dir` is set so the final `self.out_dir` is the + # concatenation of `self.out_dir` and the last level directory of + # `runner.work_dir` + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'The best checkpoint will be saved to {self.out_dir} by ' + f'{self.file_client.name}')) + + if self.save_best is not None: + if runner.meta is None: + warnings.warn('runner.meta is None. Creating an empty one.') + runner.meta = dict() + runner.meta.setdefault('hook_msgs', dict()) + self.best_ckpt_path = runner.meta['hook_msgs'].get( + 'best_ckpt', None) + + def before_train_iter(self, runner): + """Evaluate the model only at the start of training by iteration.""" + if self.by_epoch or not self.initial_flag: + return + if self.start is not None and runner.iter >= self.start: + self.after_train_iter(runner) + self.initial_flag = False + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + if not (self.by_epoch and self.initial_flag): + return + if self.start is not None and runner.epoch >= self.start: + self.after_train_epoch(runner) + self.initial_flag = False + + def after_train_iter(self, runner): + """Called after every training iter to evaluate the results.""" + if not self.by_epoch and self._should_evaluate(runner): + # Because the priority of EvalHook is higher than LoggerHook, the + # training log and the evaluating log are mixed. Therefore, + # we need to dump the training log and clear it before evaluating + # log is generated. In addition, this problem will only appear in + # `IterBasedRunner` whose `self.by_epoch` is False, because + # `EpochBasedRunner` whose `self.by_epoch` is True calls + # `_do_evaluate` in `after_train_epoch` stage, and at this stage + # the training log has been printed, so it will not cause any + # problem. more details at + # https://github.com/open-mmlab/mmsegmentation/issues/694 + for hook in runner._hooks: + if isinstance(hook, LoggerHook): + hook.after_train_iter(runner) + runner.log_buffer.clear() + + self._do_evaluate(runner) + + def after_train_epoch(self, runner): + """Called after every training epoch to evaluate the results.""" + if self.by_epoch and self._should_evaluate(runner): + self._do_evaluate(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + results = self.test_fn(runner.model, self.dataloader) + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to save + # the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) + + def _should_evaluate(self, runner): + """Judge whether to perform evaluation. + + Here is the rule to judge whether to perform evaluation: + 1. It will not perform evaluation during the epoch/iteration interval, + which is determined by ``self.interval``. + 2. It will not perform evaluation if the start time is larger than + current time. + 3. It will not perform evaluation when current time is larger than + the start time but during epoch/iteration interval. + + Returns: + bool: The flag indicating whether to perform evaluation. + """ + if self.by_epoch: + current = runner.epoch + check_time = self.every_n_epochs + else: + current = runner.iter + check_time = self.every_n_iters + + if self.start is None: + if not check_time(runner, self.interval): + # No evaluation during the interval. + return False + elif (current + 1) < self.start: + # No evaluation if start is larger than the current time. + return False + else: + # Evaluation only at epochs/iters 3, 5, 7... + # if start==3 and interval==2 + if (current + 1 - self.start) % self.interval: + return False + return True + + def _save_ckpt(self, runner, key_score): + """Save the best checkpoint. + + It will compare the score according to the compare function, write + related information (best score, best checkpoint path) and save the + best checkpoint into ``work_dir``. + """ + if self.by_epoch: + current = f'epoch_{runner.epoch + 1}' + cur_type, cur_time = 'epoch', runner.epoch + 1 + else: + current = f'iter_{runner.iter + 1}' + cur_type, cur_time = 'iter', runner.iter + 1 + + best_score = runner.meta['hook_msgs'].get( + 'best_score', self.init_value_map[self.rule]) + if self.compare_func(key_score, best_score): + best_score = key_score + runner.meta['hook_msgs']['best_score'] = best_score + + if self.best_ckpt_path and self.file_client.isfile( + self.best_ckpt_path): + self.file_client.remove(self.best_ckpt_path) + runner.logger.info( + (f'The previous best checkpoint {self.best_ckpt_path} was ' + 'removed')) + + best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' + self.best_ckpt_path = self.file_client.join_path( + self.out_dir, best_ckpt_name) + runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path + + runner.save_checkpoint( + self.out_dir, best_ckpt_name, create_symlink=False) + runner.logger.info( + f'Now best checkpoint is saved as {best_ckpt_name}.') + runner.logger.info( + f'Best {self.key_indicator} is {best_score:0.4f} ' + f'at {cur_time} {cur_type}.') + + def evaluate(self, runner, results): + """Evaluate the results. + + Args: + runner (:obj:`mmcv.Runner`): The underlined training runner. + results (list): Output results. + """ + eval_res = self.dataloader.dataset.evaluate( + results, logger=runner.logger, **self.eval_kwargs) + + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + if self.save_best is not None: + # If the performance of model is pool, the `eval_res` may be an + # empty dict and it will raise exception when `self.save_best` is + # not None. More details at + # https://github.com/open-mmlab/mmdetection/issues/6265. + if not eval_res: + warnings.warn( + 'Since `eval_res` is an empty dict, the behavior to save ' + 'the best checkpoint will be skipped in this evaluation.') + return None + + if self.key_indicator == 'auto': + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) + return eval_res[self.key_indicator] + + return None + + +class DistEvalHook(EvalHook): + """Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + default: True. + save_best (str, optional): If a metric is specified, it would measure + the best checkpoint during evaluation. The information about best + checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep + best score value and best checkpoint path, which will be also + loaded when resume checkpoint. Options are the evaluation metrics + on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox + detection and instance segmentation. ``AR@100`` for proposal + recall. If ``save_best`` is ``auto``, the first key of the returned + ``OrderedDict`` result will be used. Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader in a multi-gpu manner, and return the test results. If + ``None``, the default test function ``mmcv.engine.multi_gpu_test`` + will be used. (default: ``None``) + tmpdir (str | None): Temporary directory to save the results of all + processes. Default: None. + gpu_collect (bool): Whether to use gpu or cpu to collect results. + Default: False. + broadcast_bn_buffer (bool): Whether to broadcast the + buffer(running_mean and running_var) of rank 0 to other rank + before evaluation. Default: True. + out_dir (str, optional): The root directory to save checkpoints. If not + specified, `runner.work_dir` will be used by default. If specified, + the `out_dir` will be the concatenation of `out_dir` and the last + level directory of `runner.work_dir`. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. Default: None. + **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + """ + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=None, + less_keys=None, + broadcast_bn_buffer=True, + tmpdir=None, + gpu_collect=False, + out_dir=None, + file_client_args=None, + **eval_kwargs): + + if test_fn is None: + from annotator.uniformer.mmcv.engine import multi_gpu_test + test_fn = multi_gpu_test + + super().__init__( + dataloader, + start=start, + interval=interval, + by_epoch=by_epoch, + save_best=save_best, + rule=rule, + test_fn=test_fn, + greater_keys=greater_keys, + less_keys=less_keys, + out_dir=out_dir, + file_client_args=file_client_args, + **eval_kwargs) + + self.broadcast_bn_buffer = broadcast_bn_buffer + self.tmpdir = tmpdir + self.gpu_collect = gpu_collect + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + results = self.test_fn( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + # the key_score may be `None` so it needs to skip the action to + # save the best checkpoint + if self.save_best and key_score: + self._save_ckpt(runner, key_score) diff --git a/annotator/uniformer/mmcv/runner/hooks/hook.py b/annotator/uniformer/mmcv/runner/hooks/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b8855c107727ecf85b917c890fc8b7f6359238a4 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/hook.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from annotator.uniformer.mmcv.utils import Registry, is_method_overridden + +HOOKS = Registry('hook') + + +class Hook: + stages = ('before_run', 'before_train_epoch', 'before_train_iter', + 'after_train_iter', 'after_train_epoch', 'before_val_epoch', + 'before_val_iter', 'after_val_iter', 'after_val_epoch', + 'after_run') + + def before_run(self, runner): + pass + + def after_run(self, runner): + pass + + def before_epoch(self, runner): + pass + + def after_epoch(self, runner): + pass + + def before_iter(self, runner): + pass + + def after_iter(self, runner): + pass + + def before_train_epoch(self, runner): + self.before_epoch(runner) + + def before_val_epoch(self, runner): + self.before_epoch(runner) + + def after_train_epoch(self, runner): + self.after_epoch(runner) + + def after_val_epoch(self, runner): + self.after_epoch(runner) + + def before_train_iter(self, runner): + self.before_iter(runner) + + def before_val_iter(self, runner): + self.before_iter(runner) + + def after_train_iter(self, runner): + self.after_iter(runner) + + def after_val_iter(self, runner): + self.after_iter(runner) + + def every_n_epochs(self, runner, n): + return (runner.epoch + 1) % n == 0 if n > 0 else False + + def every_n_inner_iters(self, runner, n): + return (runner.inner_iter + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, runner, n): + return (runner.iter + 1) % n == 0 if n > 0 else False + + def end_of_epoch(self, runner): + return runner.inner_iter + 1 == len(runner.data_loader) + + def is_last_epoch(self, runner): + return runner.epoch + 1 == runner._max_epochs + + def is_last_iter(self, runner): + return runner.iter + 1 == runner._max_iters + + def get_triggered_stages(self): + trigger_stages = set() + for stage in Hook.stages: + if is_method_overridden(stage, Hook, self): + trigger_stages.add(stage) + + # some methods will be triggered in multi stages + # use this dict to map method to stages. + method_stages_map = { + 'before_epoch': ['before_train_epoch', 'before_val_epoch'], + 'after_epoch': ['after_train_epoch', 'after_val_epoch'], + 'before_iter': ['before_train_iter', 'before_val_iter'], + 'after_iter': ['after_train_iter', 'after_val_iter'], + } + + for method, map_stages in method_stages_map.items(): + if is_method_overridden(method, Hook, self): + trigger_stages.update(map_stages) + + return [stage for stage in Hook.stages if stage in trigger_stages] diff --git a/annotator/uniformer/mmcv/runner/hooks/iter_timer.py b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class IterTimerHook(Hook): + + def before_epoch(self, runner): + self.t = time.time() + + def before_iter(self, runner): + runner.log_buffer.update({'data_time': time.time() - self.t}) + + def after_iter(self, runner): + runner.log_buffer.update({'time': time.time() - self.t}) + self.t = time.time() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import LoggerHook +from .dvclive import DvcliveLoggerHook +from .mlflow import MlflowLoggerHook +from .neptune import NeptuneLoggerHook +from .pavi import PaviLoggerHook +from .tensorboard import TensorboardLoggerHook +from .text import TextLoggerHook +from .wandb import WandbLoggerHook + +__all__ = [ + 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', + 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', + 'NeptuneLoggerHook', 'DvcliveLoggerHook' +] diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/base.py b/annotator/uniformer/mmcv/runner/hooks/logger/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/base.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch + +from ..hook import Hook + + +class LoggerHook(Hook): + """Base class for logger hooks. + + Args: + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging. + by_epoch (bool): Whether EpochBasedRunner is used. + """ + + __metaclass__ = ABCMeta + + def __init__(self, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + self.interval = interval + self.ignore_last = ignore_last + self.reset_flag = reset_flag + self.by_epoch = by_epoch + + @abstractmethod + def log(self, runner): + pass + + @staticmethod + def is_scalar(val, include_np=True, include_torch=True): + """Tell the input variable is a scalar or not. + + Args: + val: Input variable. + include_np (bool): Whether include 0-d np.ndarray as a scalar. + include_torch (bool): Whether include 0-d torch.Tensor as a scalar. + + Returns: + bool: True or False. + """ + if isinstance(val, numbers.Number): + return True + elif include_np and isinstance(val, np.ndarray) and val.ndim == 0: + return True + elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1: + return True + else: + return False + + def get_mode(self, runner): + if runner.mode == 'train': + if 'time' in runner.log_buffer.output: + mode = 'train' + else: + mode = 'val' + elif runner.mode == 'val': + mode = 'val' + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return mode + + def get_epoch(self, runner): + if runner.mode == 'train': + epoch = runner.epoch + 1 + elif runner.mode == 'val': + # normal val mode + # runner.epoch += 1 has been done before val workflow + epoch = runner.epoch + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return epoch + + def get_iter(self, runner, inner_iter=False): + """Get the current training iteration step.""" + if self.by_epoch and inner_iter: + current_iter = runner.inner_iter + 1 + else: + current_iter = runner.iter + 1 + return current_iter + + def get_lr_tags(self, runner): + tags = {} + lrs = runner.current_lr() + if isinstance(lrs, dict): + for name, value in lrs.items(): + tags[f'learning_rate/{name}'] = value[0] + else: + tags['learning_rate'] = lrs[0] + return tags + + def get_momentum_tags(self, runner): + tags = {} + momentums = runner.current_momentum() + if isinstance(momentums, dict): + for name, value in momentums.items(): + tags[f'momentum/{name}'] = value[0] + else: + tags['momentum'] = momentums[0] + return tags + + def get_loggable_tags(self, + runner, + allow_scalar=True, + allow_text=False, + add_mode=True, + tags_to_skip=('time', 'data_time')): + tags = {} + for var, val in runner.log_buffer.output.items(): + if var in tags_to_skip: + continue + if self.is_scalar(val) and not allow_scalar: + continue + if isinstance(val, str) and not allow_text: + continue + if add_mode: + var = f'{self.get_mode(runner)}/{var}' + tags[var] = val + tags.update(self.get_lr_tags(runner)) + tags.update(self.get_momentum_tags(runner)) + return tags + + def before_run(self, runner): + for hook in runner.hooks[::-1]: + if isinstance(hook, LoggerHook): + hook.reset_flag = True + break + + def before_epoch(self, runner): + runner.log_buffer.clear() # clear logs of last epoch + + def after_train_iter(self, runner): + if self.by_epoch and self.every_n_inner_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif not self.by_epoch and self.every_n_iters(runner, self.interval): + runner.log_buffer.average(self.interval) + elif self.end_of_epoch(runner) and not self.ignore_last: + # not precise but more stable + runner.log_buffer.average(self.interval) + + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_train_epoch(self, runner): + if runner.log_buffer.ready: + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() + + def after_val_epoch(self, runner): + runner.log_buffer.average() + self.log(runner) + if self.reset_flag: + runner.log_buffer.clear_output() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py new file mode 100644 index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class DvcliveLoggerHook(LoggerHook): + """Class to log metrics with dvclive. + + It requires `dvclive`_ to be installed. + + Args: + path (str): Directory where dvclive will write TSV log files. + interval (int): Logging interval (every k iterations). + Default 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: True. + by_epoch (bool): Whether EpochBasedRunner is used. + Default: True. + + .. _dvclive: + https://dvc.org/doc/dvclive + """ + + def __init__(self, + path, + interval=10, + ignore_last=True, + reset_flag=True, + by_epoch=True): + + super(DvcliveLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.path = path + self.import_dvclive() + + def import_dvclive(self): + try: + import dvclive + except ImportError: + raise ImportError( + 'Please run "pip install dvclive" to install dvclive') + self.dvclive = dvclive + + @master_only + def before_run(self, runner): + self.dvclive.init(self.path) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + for k, v in tags.items(): + self.dvclive.log(k, v, step=self.get_iter(runner)) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class MlflowLoggerHook(LoggerHook): + + def __init__(self, + exp_name=None, + tags=None, + log_model=True, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + """Class to log metrics and (optionally) a trained model to MLflow. + + It requires `MLflow`_ to be installed. + + Args: + exp_name (str, optional): Name of the experiment to be used. + Default None. + If not None, set the active experiment. + If experiment does not exist, an experiment with provided name + will be created. + tags (dict of str: str, optional): Tags for the current run. + Default None. + If not None, set tags for the current run. + log_model (bool, optional): Whether to log an MLflow artifact. + Default True. + If True, log runner.model as an MLflow artifact + for the current run. + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging + by_epoch (bool): Whether EpochBasedRunner is used. + + .. _MLflow: + https://www.mlflow.org/docs/latest/index.html + """ + super(MlflowLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_mlflow() + self.exp_name = exp_name + self.tags = tags + self.log_model = log_model + + def import_mlflow(self): + try: + import mlflow + import mlflow.pytorch as mlflow_pytorch + except ImportError: + raise ImportError( + 'Please run "pip install mlflow" to install mlflow') + self.mlflow = mlflow + self.mlflow_pytorch = mlflow_pytorch + + @master_only + def before_run(self, runner): + super(MlflowLoggerHook, self).before_run(runner) + if self.exp_name is not None: + self.mlflow.set_experiment(self.exp_name) + if self.tags is not None: + self.mlflow.set_tags(self.tags) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + self.mlflow.log_metrics(tags, step=self.get_iter(runner)) + + @master_only + def after_run(self, runner): + if self.log_model: + self.mlflow_pytorch.log_model(runner.model, 'models') diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py new file mode 100644 index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class NeptuneLoggerHook(LoggerHook): + """Class to log metrics to NeptuneAI. + + It requires `neptune-client` to be installed. + + Args: + init_kwargs (dict): a dict contains the initialization keys as below: + - project (str): Name of a project in a form of + namespace/project_name. If None, the value of + NEPTUNE_PROJECT environment variable will be taken. + - api_token (str): User’s API token. + If None, the value of NEPTUNE_API_TOKEN environment + variable will be taken. Note: It is strongly recommended + to use NEPTUNE_API_TOKEN environment variable rather than + placing your API token in plain text in your source code. + - name (str, optional, default is 'Untitled'): Editable name of + the run. Name is displayed in the run's Details and in + Runs table as a column. + Check https://docs.neptune.ai/api-reference/neptune#init for + more init arguments. + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging + by_epoch (bool): Whether EpochBasedRunner is used. + + .. _NeptuneAI: + https://docs.neptune.ai/you-should-know/logging-metadata + """ + + def __init__(self, + init_kwargs=None, + interval=10, + ignore_last=True, + reset_flag=True, + with_step=True, + by_epoch=True): + + super(NeptuneLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_neptune() + self.init_kwargs = init_kwargs + self.with_step = with_step + + def import_neptune(self): + try: + import neptune.new as neptune + except ImportError: + raise ImportError( + 'Please run "pip install neptune-client" to install neptune') + self.neptune = neptune + self.run = None + + @master_only + def before_run(self, runner): + if self.init_kwargs: + self.run = self.neptune.init(**self.init_kwargs) + else: + self.run = self.neptune.init() + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + for tag_name, tag_value in tags.items(): + if self.with_step: + self.run[tag_name].log( + tag_value, step=self.get_iter(runner)) + else: + tags['global_step'] = self.get_iter(runner) + self.run[tag_name].log(tags) + + @master_only + def after_run(self, runner): + self.run.stop() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcf146d8163aff1363e9764999b0a74d674a595 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os +import os.path as osp + +import torch +import yaml + +import annotator.uniformer.mmcv as mmcv +from ....parallel.utils import is_module_wrapper +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class PaviLoggerHook(LoggerHook): + + def __init__(self, + init_kwargs=None, + add_graph=False, + add_last_ckpt=False, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True, + img_key='img_info'): + super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag, + by_epoch) + self.init_kwargs = init_kwargs + self.add_graph = add_graph + self.add_last_ckpt = add_last_ckpt + self.img_key = img_key + + @master_only + def before_run(self, runner): + super(PaviLoggerHook, self).before_run(runner) + try: + from pavi import SummaryWriter + except ImportError: + raise ImportError('Please run "pip install pavi" to install pavi.') + + self.run_name = runner.work_dir.split('/')[-1] + + if not self.init_kwargs: + self.init_kwargs = dict() + self.init_kwargs['name'] = self.run_name + self.init_kwargs['model'] = runner._model_name + if runner.meta is not None: + if 'config_dict' in runner.meta: + config_dict = runner.meta['config_dict'] + assert isinstance( + config_dict, + dict), ('meta["config_dict"] has to be of a dict, ' + f'but got {type(config_dict)}') + elif 'config_file' in runner.meta: + config_file = runner.meta['config_file'] + config_dict = dict(mmcv.Config.fromfile(config_file)) + else: + config_dict = None + if config_dict is not None: + # 'max_.*iter' is parsed in pavi sdk as the maximum iterations + # to properly set up the progress bar. + config_dict = config_dict.copy() + config_dict.setdefault('max_iter', runner.max_iters) + # non-serializable values are first converted in + # mmcv.dump to json + config_dict = json.loads( + mmcv.dump(config_dict, file_format='json')) + session_text = yaml.dump(config_dict) + self.init_kwargs['session_text'] = session_text + self.writer = SummaryWriter(**self.init_kwargs) + + def get_step(self, runner): + """Get the total training step/epoch.""" + if self.get_mode(runner) == 'val' and self.by_epoch: + return self.get_epoch(runner) + else: + return self.get_iter(runner) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner, add_mode=False) + if tags: + self.writer.add_scalars( + self.get_mode(runner), tags, self.get_step(runner)) + + @master_only + def after_run(self, runner): + if self.add_last_ckpt: + ckpt_path = osp.join(runner.work_dir, 'latest.pth') + if osp.islink(ckpt_path): + ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path)) + + if osp.isfile(ckpt_path): + # runner.epoch += 1 has been done before `after_run`. + iteration = runner.epoch if self.by_epoch else runner.iter + return self.writer.add_snapshot_file( + tag=self.run_name, + snapshot_file_path=ckpt_path, + iteration=iteration) + + # flush the buffer and send a task ending signal to Pavi + self.writer.close() + + @master_only + def before_epoch(self, runner): + if runner.epoch == 0 and self.add_graph: + if is_module_wrapper(runner.model): + _model = runner.model.module + else: + _model = runner.model + device = next(_model.parameters()).device + data = next(iter(runner.data_loader)) + image = data[self.img_key][0:1].to(device) + with torch.no_grad(): + self.writer.add_graph(_model, image) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5011dc08def6c09eef86d3ce5b124c9fc5372 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TensorboardLoggerHook(LoggerHook): + + def __init__(self, + log_dir=None, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + super(TensorboardLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.log_dir = log_dir + + @master_only + def before_run(self, runner): + super(TensorboardLoggerHook, self).before_run(runner) + if (TORCH_VERSION == 'parrots' + or digit_version(TORCH_VERSION) < digit_version('1.1')): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorboardX to use ' + 'TensorboardLoggerHook.') + else: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + 'Please run "pip install future tensorboard" to install ' + 'the dependencies to use torch.utils.tensorboard ' + '(applicable to PyTorch 1.1 or higher)') + + if self.log_dir is None: + self.log_dir = osp.join(runner.work_dir, 'tf_logs') + self.writer = SummaryWriter(self.log_dir) + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner, allow_text=True) + for tag, val in tags.items(): + if isinstance(val, str): + self.writer.add_text(tag, val, self.get_iter(runner)) + else: + self.writer.add_scalar(tag, val, self.get_iter(runner)) + + @master_only + def after_run(self, runner): + self.writer.close() diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/text.py b/annotator/uniformer/mmcv/runner/hooks/logger/text.py new file mode 100644 index 0000000000000000000000000000000000000000..87b1a3eca9595a130121526f8b4c29915387ab35 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/text.py @@ -0,0 +1,256 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os +import os.path as osp +from collections import OrderedDict + +import torch +import torch.distributed as dist + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.fileio.file_client import FileClient +from annotator.uniformer.mmcv.utils import is_tuple_of, scandir +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class TextLoggerHook(LoggerHook): + """Logger hook in text. + + In this logger hook, the information will be printed on terminal and + saved in json file. + + Args: + by_epoch (bool, optional): Whether EpochBasedRunner is used. + Default: True. + interval (int, optional): Logging interval (every k iterations). + Default: 10. + ignore_last (bool, optional): Ignore the log of last iterations in each + epoch if less than :attr:`interval`. Default: True. + reset_flag (bool, optional): Whether to clear the output buffer after + logging. Default: False. + interval_exp_name (int, optional): Logging interval for experiment + name. This feature is to help users conveniently get the experiment + information from screen or log file. Default: 1000. + out_dir (str, optional): Logs are saved in ``runner.work_dir`` default. + If ``out_dir`` is specified, logs will be copied to a new directory + which is the concatenation of ``out_dir`` and the last level + directory of ``runner.work_dir``. Default: None. + `New in version 1.3.16.` + out_suffix (str or tuple[str], optional): Those filenames ending with + ``out_suffix`` will be copied to ``out_dir``. + Default: ('.log.json', '.log', '.py'). + `New in version 1.3.16.` + keep_local (bool, optional): Whether to keep local log when + :attr:`out_dir` is specified. If False, the local log will be + removed. Default: True. + `New in version 1.3.16.` + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmcv.fileio.FileClient` for details. + Default: None. + `New in version 1.3.16.` + """ + + def __init__(self, + by_epoch=True, + interval=10, + ignore_last=True, + reset_flag=False, + interval_exp_name=1000, + out_dir=None, + out_suffix=('.log.json', '.log', '.py'), + keep_local=True, + file_client_args=None): + super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, + by_epoch) + self.by_epoch = by_epoch + self.time_sec_tot = 0 + self.interval_exp_name = interval_exp_name + + if out_dir is None and file_client_args is not None: + raise ValueError( + 'file_client_args should be "None" when `out_dir` is not' + 'specified.') + self.out_dir = out_dir + + if not (out_dir is None or isinstance(out_dir, str) + or is_tuple_of(out_dir, str)): + raise TypeError('out_dir should be "None" or string or tuple of ' + 'string, but got {out_dir}') + self.out_suffix = out_suffix + + self.keep_local = keep_local + self.file_client_args = file_client_args + if self.out_dir is not None: + self.file_client = FileClient.infer_client(file_client_args, + self.out_dir) + + def before_run(self, runner): + super(TextLoggerHook, self).before_run(runner) + + if self.out_dir is not None: + self.file_client = FileClient.infer_client(self.file_client_args, + self.out_dir) + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'Text logs will be saved to {self.out_dir} by ' + f'{self.file_client.name} after the training process.')) + + self.start_iter = runner.iter + self.json_log_path = osp.join(runner.work_dir, + f'{runner.timestamp}.log.json') + if runner.meta is not None: + self._dump_log(runner.meta, runner) + + def _get_max_memory(self, runner): + device = getattr(runner.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([mem / (1024 * 1024)], + dtype=torch.int, + device=device) + if runner.world_size > 1: + dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) + return mem_mb.item() + + def _log_info(self, log_dict, runner): + # print exp name for users to distinguish experiments + # at every ``interval_exp_name`` iterations and the end of each epoch + if runner.meta is not None and 'exp_name' in runner.meta: + if (self.every_n_iters(runner, self.interval_exp_name)) or ( + self.by_epoch and self.end_of_epoch(runner)): + exp_info = f'Exp name: {runner.meta["exp_name"]}' + runner.logger.info(exp_info) + + if log_dict['mode'] == 'train': + if isinstance(log_dict['lr'], dict): + lr_str = [] + for k, val in log_dict['lr'].items(): + lr_str.append(f'lr_{k}: {val:.3e}') + lr_str = ' '.join(lr_str) + else: + lr_str = f'lr: {log_dict["lr"]:.3e}' + + # by epoch: Epoch [4][100/1000] + # by iter: Iter [100/100000] + if self.by_epoch: + log_str = f'Epoch [{log_dict["epoch"]}]' \ + f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t' + else: + log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t' + log_str += f'{lr_str}, ' + + if 'time' in log_dict.keys(): + self.time_sec_tot += (log_dict['time'] * self.interval) + time_sec_avg = self.time_sec_tot / ( + runner.iter - self.start_iter + 1) + eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + log_str += f'eta: {eta_str}, ' + log_str += f'time: {log_dict["time"]:.3f}, ' \ + f'data_time: {log_dict["data_time"]:.3f}, ' + # statistic memory + if torch.cuda.is_available(): + log_str += f'memory: {log_dict["memory"]}, ' + else: + # val/test time + # here 1000 is the length of the val dataloader + # by epoch: Epoch[val] [4][1000] + # by iter: Iter[val] [1000] + if self.by_epoch: + log_str = f'Epoch({log_dict["mode"]}) ' \ + f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t' + else: + log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t' + + log_items = [] + for name, val in log_dict.items(): + # TODO: resolve this hack + # these items have been in log_str + if name in [ + 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time', + 'memory', 'epoch' + ]: + continue + if isinstance(val, float): + val = f'{val:.4f}' + log_items.append(f'{name}: {val}') + log_str += ', '.join(log_items) + + runner.logger.info(log_str) + + def _dump_log(self, log_dict, runner): + # dump log in json format + json_log = OrderedDict() + for k, v in log_dict.items(): + json_log[k] = self._round_float(v) + # only append log at last line + if runner.rank == 0: + with open(self.json_log_path, 'a+') as f: + mmcv.dump(json_log, f, file_format='json') + f.write('\n') + + def _round_float(self, items): + if isinstance(items, list): + return [self._round_float(item) for item in items] + elif isinstance(items, float): + return round(items, 5) + else: + return items + + def log(self, runner): + if 'eval_iter_num' in runner.log_buffer.output: + # this doesn't modify runner.iter and is regardless of by_epoch + cur_iter = runner.log_buffer.output.pop('eval_iter_num') + else: + cur_iter = self.get_iter(runner, inner_iter=True) + + log_dict = OrderedDict( + mode=self.get_mode(runner), + epoch=self.get_epoch(runner), + iter=cur_iter) + + # only record lr of the first param group + cur_lr = runner.current_lr() + if isinstance(cur_lr, list): + log_dict['lr'] = cur_lr[0] + else: + assert isinstance(cur_lr, dict) + log_dict['lr'] = {} + for k, lr_ in cur_lr.items(): + assert isinstance(lr_, list) + log_dict['lr'].update({k: lr_[0]}) + + if 'time' in runner.log_buffer.output: + # statistic memory + if torch.cuda.is_available(): + log_dict['memory'] = self._get_max_memory(runner) + + log_dict = dict(log_dict, **runner.log_buffer.output) + + self._log_info(log_dict, runner) + self._dump_log(log_dict, runner) + return log_dict + + def after_run(self, runner): + # copy or upload logs to self.out_dir + if self.out_dir is not None: + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + out_filepath = self.file_client.join_path( + self.out_dir, filename) + with open(local_filepath, 'r') as f: + self.file_client.put_text(f.read(), out_filepath) + + runner.logger.info( + (f'The file {local_filepath} has been uploaded to ' + f'{out_filepath}.')) + + if not self.keep_local: + os.remove(local_filepath) + runner.logger.info( + (f'{local_filepath} was removed due to the ' + '`self.keep_local=False`')) diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class WandbLoggerHook(LoggerHook): + + def __init__(self, + init_kwargs=None, + interval=10, + ignore_last=True, + reset_flag=False, + commit=True, + by_epoch=True, + with_step=True): + super(WandbLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_wandb() + self.init_kwargs = init_kwargs + self.commit = commit + self.with_step = with_step + + def import_wandb(self): + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + self.wandb = wandb + + @master_only + def before_run(self, runner): + super(WandbLoggerHook, self).before_run(runner) + if self.wandb is None: + self.import_wandb() + if self.init_kwargs: + self.wandb.init(**self.init_kwargs) + else: + self.wandb.init() + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + if self.with_step: + self.wandb.log( + tags, step=self.get_iter(runner), commit=self.commit) + else: + tags['global_step'] = self.get_iter(runner) + self.wandb.log(tags, commit=self.commit) + + @master_only + def after_run(self, runner): + self.wandb.join() diff --git a/annotator/uniformer/mmcv/runner/hooks/lr_updater.py b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..6365908ddf6070086de2ffc0afada46ed2f32256 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py @@ -0,0 +1,670 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numbers +from math import cos, pi + +import annotator.uniformer.mmcv as mmcv +from .hook import HOOKS, Hook + + +class LrUpdaterHook(Hook): + """LR Scheduler in MMCV. + + Args: + by_epoch (bool): LR changes epoch by epoch + warmup (string): Type of warmup used. It can be None(use no warmup), + 'constant', 'linear' or 'exp' + warmup_iters (int): The number of iterations or epochs that warmup + lasts + warmup_ratio (float): LR used at the beginning of warmup equals to + warmup_ratio * initial_lr + warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters + means the number of epochs that warmup lasts, otherwise means the + number of iteration that warmup lasts + """ + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.1, + warmup_by_epoch=False): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant" and "linear"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_ratio" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + self.warmup_by_epoch = warmup_by_epoch + + if self.warmup_by_epoch: + self.warmup_epochs = self.warmup_iters + self.warmup_iters = None + else: + self.warmup_epochs = None + + self.base_lr = [] # initial lr for all param groups + self.regular_lr = [] # expected lr if no warming up is performed + + def _set_lr(self, runner, lr_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, lr in zip(optim.param_groups, lr_groups[k]): + param_group['lr'] = lr + else: + for param_group, lr in zip(runner.optimizer.param_groups, + lr_groups): + param_group['lr'] = lr + + def get_lr(self, runner, base_lr): + raise NotImplementedError + + def get_regular_lr(self, runner): + if isinstance(runner.optimizer, dict): + lr_groups = {} + for k in runner.optimizer.keys(): + _lr_group = [ + self.get_lr(runner, _base_lr) + for _base_lr in self.base_lr[k] + ] + lr_groups.update({k: _lr_group}) + + return lr_groups + else: + return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] + + def get_warmup_lr(self, cur_iters): + + def _get_warmup_lr(cur_iters, regular_lr): + if self.warmup == 'constant': + warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_lr = [_lr * (1 - k) for _lr in regular_lr] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_lr = [_lr * k for _lr in regular_lr] + return warmup_lr + + if isinstance(self.regular_lr, dict): + lr_groups = {} + for key, regular_lr in self.regular_lr.items(): + lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr) + return lr_groups + else: + return _get_warmup_lr(cur_iters, self.regular_lr) + + def before_run(self, runner): + # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, + # it will be set according to the optimizer params + if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + group.setdefault('initial_lr', group['lr']) + _base_lr = [ + group['initial_lr'] for group in optim.param_groups + ] + self.base_lr.update({k: _base_lr}) + else: + for group in runner.optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + self.base_lr = [ + group['initial_lr'] for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if self.warmup_iters is None: + epoch_len = len(runner.data_loader) + self.warmup_iters = self.warmup_epochs * epoch_len + + if not self.by_epoch: + return + + self.regular_lr = self.get_regular_lr(runner) + self._set_lr(runner, self.regular_lr) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_lr = self.get_regular_lr(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_lr(runner, self.regular_lr) + else: + warmup_lr = self.get_warmup_lr(cur_iter) + self._set_lr(runner, warmup_lr) + + +@HOOKS.register_module() +class FixedLrUpdaterHook(LrUpdaterHook): + + def __init__(self, **kwargs): + super(FixedLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + return base_lr + + +@HOOKS.register_module() +class StepLrUpdaterHook(LrUpdaterHook): + """Step LR scheduler with min_lr clipping. + + Args: + step (int | list[int]): Step to decay the LR. If an int value is given, + regard it as the decay interval. If a list is given, decay LR at + these steps. + gamma (float, optional): Decay LR ratio. Default: 0.1. + min_lr (float, optional): Minimum LR value to keep. If LR after decay + is lower than `min_lr`, it will be clipped to this value. If None + is given, we don't perform lr clipping. Default: None. + """ + + def __init__(self, step, gamma=0.1, min_lr=None, **kwargs): + if isinstance(step, list): + assert mmcv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_lr = min_lr + super(StepLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + lr = base_lr * (self.gamma**exp) + if self.min_lr is not None: + # clip to a minimum value + lr = max(lr, self.min_lr) + return lr + + +@HOOKS.register_module() +class ExpLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, **kwargs): + self.gamma = gamma + super(ExpLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * self.gamma**progress + + +@HOOKS.register_module() +class PolyLrUpdaterHook(LrUpdaterHook): + + def __init__(self, power=1., min_lr=0., **kwargs): + self.power = power + self.min_lr = min_lr + super(PolyLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + coeff = (1 - progress / max_progress)**self.power + return (base_lr - self.min_lr) * coeff + self.min_lr + + +@HOOKS.register_module() +class InvLrUpdaterHook(LrUpdaterHook): + + def __init__(self, gamma, power=1., **kwargs): + self.gamma = gamma + self.power = power + super(InvLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + progress = runner.epoch if self.by_epoch else runner.iter + return base_lr * (1 + self.gamma * progress)**(-self.power) + + +@HOOKS.register_module() +class CosineAnnealingLrUpdaterHook(LrUpdaterHook): + + def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook): + """Flat + Cosine lr schedule. + + Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501 + + Args: + start_percent (float): When to start annealing the learning rate + after the percentage of the total training steps. + The value should be in range [0, 1). + Default: 0.75 + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + start_percent=0.75, + min_lr=None, + min_lr_ratio=None, + **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + if start_percent < 0 or start_percent > 1 or not isinstance( + start_percent, float): + raise ValueError( + 'expected float between 0 and 1 start_percent, but ' + f'got {start_percent}') + self.start_percent = start_percent + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + start = round(runner.max_epochs * self.start_percent) + progress = runner.epoch - start + max_progress = runner.max_epochs - start + else: + start = round(runner.max_iters * self.start_percent) + progress = runner.iter - start + max_progress = runner.max_iters - start + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + + if progress < 0: + return base_lr + else: + return annealing_cos(base_lr, target_lr, progress / max_progress) + + +@HOOKS.register_module() +class CosineRestartLrUpdaterHook(LrUpdaterHook): + """Cosine annealing with restarts learning rate scheme. + + Args: + periods (list[int]): Periods for each cosine anneling cycle. + restart_weights (list[float], optional): Restart weights at each + restart iteration. Default: [1]. + min_lr (float, optional): The minimum lr. Default: None. + min_lr_ratio (float, optional): The ratio of minimum lr to the base lr. + Either `min_lr` or `min_lr_ratio` should be specified. + Default: None. + """ + + def __init__(self, + periods, + restart_weights=[1], + min_lr=None, + min_lr_ratio=None, + **kwargs): + assert (min_lr is None) ^ (min_lr_ratio is None) + self.periods = periods + self.min_lr = min_lr + self.min_lr_ratio = min_lr_ratio + self.restart_weights = restart_weights + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' + super(CosineRestartLrUpdaterHook, self).__init__(**kwargs) + + self.cumulative_periods = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + else: + progress = runner.iter + + if self.min_lr_ratio is not None: + target_lr = base_lr * self.min_lr_ratio + else: + target_lr = self.min_lr + + idx = get_position_from_periods(progress, self.cumulative_periods) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] + current_periods = self.periods[idx] + + alpha = min((progress - nearest_restart) / current_periods, 1) + return annealing_cos(base_lr, target_lr, alpha, current_weight) + + +def get_position_from_periods(iteration, cumulative_periods): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_periods = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 3. + + Args: + iteration (int): Current iteration. + cumulative_periods (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_periods): + if iteration < period: + return i + raise ValueError(f'Current iteration {iteration} exceeds ' + f'cumulative_periods {cumulative_periods}') + + +@HOOKS.register_module() +class CyclicLrUpdaterHook(LrUpdaterHook): + """Cyclic LR Scheduler. + + Implement the cyclical learning rate policy (CLR) described in + https://arxiv.org/pdf/1506.01186.pdf + + Different from the original paper, we use cosine annealing rather than + triangular policy inside a cycle. This improves the performance in the + 3D detection area. + + Args: + by_epoch (bool): Whether to update LR by epoch. + target_ratio (tuple[float]): Relative ratio of the highest LR and the + lowest LR to the initial LR. + cyclic_times (int): Number of cycles during training + step_ratio_up (float): The ratio of the increasing process of LR in + the total cycle. + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + """ + + def __init__(self, + by_epoch=False, + target_ratio=(10, 1e-4), + cyclic_times=1, + step_ratio_up=0.4, + anneal_strategy='cos', + **kwargs): + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.lr_phases = [] # init lr_phases + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs) + + def before_run(self, runner): + super(CyclicLrUpdaterHook, self).before_run(runner) + # initiate lr_phases + # total lr_phases are separated as up and down + max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * max_iter_per_phase) + self.lr_phases.append( + [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]]) + self.lr_phases.append([ + iter_up_phase, max_iter_per_phase, max_iter_per_phase, + self.target_ratio[0], self.target_ratio[1] + ]) + + def get_lr(self, runner, base_lr): + curr_iter = runner.iter + for (start_iter, end_iter, max_iter_per_phase, start_ratio, + end_ratio) in self.lr_phases: + curr_iter %= max_iter_per_phase + if start_iter <= curr_iter < end_iter: + progress = curr_iter - start_iter + return self.anneal_func(base_lr * start_ratio, + base_lr * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleLrUpdaterHook(LrUpdaterHook): + """One Cycle LR Scheduler. + + The 1cycle learning rate policy changes the learning rate after every + batch. The one cycle learning rate policy is described in + https://arxiv.org/pdf/1708.07120.pdf + + Args: + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int, optional): The total number of steps in the cycle. + Note that if a value is not provided here, it will be the max_iter + of runner. Default: None. + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). + Default: False + """ + + def __init__(self, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy='cos', + div_factor=25, + final_div_factor=1e4, + three_phase=False, + **kwargs): + # validate by_epoch, currently only support by_epoch = False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(max_lr, (numbers.Number, list, dict)): + raise ValueError('the type of max_lr must be the one of list or ' + f'dict, but got {type(max_lr)}') + self._max_lr = max_lr + if total_steps is not None: + if not isinstance(total_steps, int): + raise ValueError('the type of total_steps must be int, but' + f'got {type(total_steps)}') + self.total_steps = total_steps + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must be one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.div_factor = div_factor + self.final_div_factor = final_div_factor + self.three_phase = three_phase + self.lr_phases = [] # init lr_phases + super(OneCycleLrUpdaterHook, self).__init__(**kwargs) + + def before_run(self, runner): + if hasattr(self, 'total_steps'): + total_steps = self.total_steps + else: + total_steps = runner.max_iters + if total_steps < runner.max_iters: + raise ValueError( + 'The total steps must be greater than or equal to max ' + f'iterations {runner.max_iters} of runner, but total steps ' + f'is {total_steps}.') + + if isinstance(runner.optimizer, dict): + self.base_lr = {} + for k, optim in runner.optimizer.items(): + _max_lr = format_param(k, optim, self._max_lr) + self.base_lr[k] = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(optim.param_groups, self.base_lr[k]): + group.setdefault('initial_lr', lr) + else: + k = type(runner.optimizer).__name__ + _max_lr = format_param(k, runner.optimizer, self._max_lr) + self.base_lr = [lr / self.div_factor for lr in _max_lr] + for group, lr in zip(runner.optimizer.param_groups, self.base_lr): + group.setdefault('initial_lr', lr) + + if self.three_phase: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append([ + float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1 + ]) + self.lr_phases.append( + [total_steps - 1, 1, 1 / self.final_div_factor]) + else: + self.lr_phases.append( + [float(self.pct_start * total_steps) - 1, 1, self.div_factor]) + self.lr_phases.append( + [total_steps - 1, self.div_factor, 1 / self.final_div_factor]) + + def get_lr(self, runner, base_lr): + curr_iter = runner.iter + start_iter = 0 + for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases): + if curr_iter <= end_iter: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr, + pct) + break + start_iter = end_iter + return lr + + +def annealing_cos(start, end, factor, weight=1): + """Calculate annealing cos learning rate. + + Cosine anneal from `weight * start + (1 - weight) * end` to `end` as + percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the cosine annealing. + end (float): The ending learing rate of the cosine annealing. + factor (float): The coefficient of `pi` when calculating the current + percentage. Range from 0.0 to 1.0. + weight (float, optional): The combination factor of `start` and `end` + when calculating the actual starting learning rate. Default to 1. + """ + cos_out = cos(pi * factor) + 1 + return end + 0.5 * weight * (start - end) * cos_out + + +def annealing_linear(start, end, factor): + """Calculate annealing linear learning rate. + + Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0. + + Args: + start (float): The starting learning rate of the linear annealing. + end (float): The ending learing rate of the linear annealing. + factor (float): The coefficient of `pi` when calculating the current + percentage. Range from 0.0 to 1.0. + """ + return start + (end - start) * factor + + +def format_param(name, optim, param): + if isinstance(param, numbers.Number): + return [param] * len(optim.param_groups) + elif isinstance(param, (list, tuple)): # multi param groups + if len(param) != len(optim.param_groups): + raise ValueError(f'expected {len(optim.param_groups)} ' + f'values for {name}, got {len(param)}') + return param + else: # multi optimizers + if name not in param: + raise KeyError(f'{name} is not found in {param.keys()}') + return param[name] diff --git a/annotator/uniformer/mmcv/runner/hooks/memory.py b/annotator/uniformer/mmcv/runner/hooks/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/memory.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class EmptyCacheHook(Hook): + + def __init__(self, before_epoch=False, after_epoch=True, after_iter=False): + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, runner): + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner): + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner): + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..60437756ceedf06055ec349df69a25465738d3f0 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py @@ -0,0 +1,493 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import annotator.uniformer.mmcv as mmcv +from .hook import HOOKS, Hook +from .lr_updater import annealing_cos, annealing_linear, format_param + + +class MomentumUpdaterHook(Hook): + + def __init__(self, + by_epoch=True, + warmup=None, + warmup_iters=0, + warmup_ratio=0.9): + # validate the "warmup" argument + if warmup is not None: + if warmup not in ['constant', 'linear', 'exp']: + raise ValueError( + f'"{warmup}" is not a supported type for warming up, valid' + ' types are "constant" and "linear"') + if warmup is not None: + assert warmup_iters > 0, \ + '"warmup_iters" must be a positive integer' + assert 0 < warmup_ratio <= 1.0, \ + '"warmup_momentum" must be in range (0,1]' + + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_iters = warmup_iters + self.warmup_ratio = warmup_ratio + + self.base_momentum = [] # initial momentum for all param groups + self.regular_momentum = [ + ] # expected momentum if no warming up is performed + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, base_momentum): + raise NotImplementedError + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k in runner.optimizer.keys(): + _momentum_group = [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum[k] + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + return [ + self.get_momentum(runner, _base_momentum) + for _base_momentum in self.base_momentum + ] + + def get_warmup_momentum(self, cur_iters): + + def _get_warmup_momentum(cur_iters, regular_momentum): + if self.warmup == 'constant': + warmup_momentum = [ + _momentum / self.warmup_ratio + for _momentum in self.regular_momentum + ] + elif self.warmup == 'linear': + k = (1 - cur_iters / self.warmup_iters) * (1 - + self.warmup_ratio) + warmup_momentum = [ + _momentum / (1 - k) for _momentum in self.regular_mom + ] + elif self.warmup == 'exp': + k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) + warmup_momentum = [ + _momentum / k for _momentum in self.regular_mom + ] + return warmup_momentum + + if isinstance(self.regular_momentum, dict): + momentum_groups = {} + for key, regular_momentum in self.regular_momentum.items(): + momentum_groups[key] = _get_warmup_momentum( + cur_iters, regular_momentum) + return momentum_groups + else: + return _get_warmup_momentum(cur_iters, self.regular_momentum) + + def before_run(self, runner): + # NOTE: when resuming from a checkpoint, + # if 'initial_momentum' is not saved, + # it will be set according to the optimizer params + if isinstance(runner.optimizer, dict): + self.base_momentum = {} + for k, optim in runner.optimizer.items(): + for group in optim.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + _base_momentum = [ + group['initial_momentum'] for group in optim.param_groups + ] + self.base_momentum.update({k: _base_momentum}) + else: + for group in runner.optimizer.param_groups: + if 'momentum' in group.keys(): + group.setdefault('initial_momentum', group['momentum']) + else: + group.setdefault('initial_momentum', group['betas'][0]) + self.base_momentum = [ + group['initial_momentum'] + for group in runner.optimizer.param_groups + ] + + def before_train_epoch(self, runner): + if not self.by_epoch: + return + self.regular_mom = self.get_regular_momentum(runner) + self._set_momentum(runner, self.regular_mom) + + def before_train_iter(self, runner): + cur_iter = runner.iter + if not self.by_epoch: + self.regular_mom = self.get_regular_momentum(runner) + if self.warmup is None or cur_iter >= self.warmup_iters: + self._set_momentum(runner, self.regular_mom) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + elif self.by_epoch: + if self.warmup is None or cur_iter > self.warmup_iters: + return + elif cur_iter == self.warmup_iters: + self._set_momentum(runner, self.regular_mom) + else: + warmup_momentum = self.get_warmup_momentum(cur_iter) + self._set_momentum(runner, warmup_momentum) + + +@HOOKS.register_module() +class StepMomentumUpdaterHook(MomentumUpdaterHook): + """Step momentum scheduler with min value clipping. + + Args: + step (int | list[int]): Step to decay the momentum. If an int value is + given, regard it as the decay interval. If a list is given, decay + momentum at these steps. + gamma (float, optional): Decay momentum ratio. Default: 0.5. + min_momentum (float, optional): Minimum momentum value to keep. If + momentum after decay is lower than this value, it will be clipped + accordingly. If None is given, we don't perform lr clipping. + Default: None. + """ + + def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs): + if isinstance(step, list): + assert mmcv.is_list_of(step, int) + assert all([s > 0 for s in step]) + elif isinstance(step, int): + assert step > 0 + else: + raise TypeError('"step" must be a list or integer') + self.step = step + self.gamma = gamma + self.min_momentum = min_momentum + super(StepMomentumUpdaterHook, self).__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + progress = runner.epoch if self.by_epoch else runner.iter + + # calculate exponential term + if isinstance(self.step, int): + exp = progress // self.step + else: + exp = len(self.step) + for i, s in enumerate(self.step): + if progress < s: + exp = i + break + + momentum = base_momentum * (self.gamma**exp) + if self.min_momentum is not None: + # clip to a minimum value + momentum = max(momentum, self.min_momentum) + return momentum + + +@HOOKS.register_module() +class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook): + + def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs): + assert (min_momentum is None) ^ (min_momentum_ratio is None) + self.min_momentum = min_momentum + self.min_momentum_ratio = min_momentum_ratio + super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs) + + def get_momentum(self, runner, base_momentum): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + if self.min_momentum_ratio is not None: + target_momentum = base_momentum * self.min_momentum_ratio + else: + target_momentum = self.min_momentum + return annealing_cos(base_momentum, target_momentum, + progress / max_progress) + + +@HOOKS.register_module() +class CyclicMomentumUpdaterHook(MomentumUpdaterHook): + """Cyclic momentum Scheduler. + + Implement the cyclical momentum scheduler policy described in + https://arxiv.org/pdf/1708.07120.pdf + + This momentum scheduler usually used together with the CyclicLRUpdater + to improve the performance in the 3D detection area. + + Attributes: + target_ratio (tuple[float]): Relative ratio of the lowest momentum and + the highest momentum to the initial momentum. + cyclic_times (int): Number of cycles during training + step_ratio_up (float): The ratio of the increasing process of momentum + in the total cycle. + by_epoch (bool): Whether to update momentum by epoch. + """ + + def __init__(self, + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4, + **kwargs): + if isinstance(target_ratio, float): + target_ratio = (target_ratio, target_ratio / 1e5) + elif isinstance(target_ratio, tuple): + target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \ + if len(target_ratio) == 1 else target_ratio + else: + raise ValueError('target_ratio should be either float ' + f'or tuple, got {type(target_ratio)}') + + assert len(target_ratio) == 2, \ + '"target_ratio" must be list or tuple of two floats' + assert 0 <= step_ratio_up < 1.0, \ + '"step_ratio_up" must be in range [0,1)' + + self.target_ratio = target_ratio + self.cyclic_times = cyclic_times + self.step_ratio_up = step_ratio_up + self.momentum_phases = [] # init momentum_phases + # currently only support by_epoch=False + assert not by_epoch, \ + 'currently only support "by_epoch" = False' + super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs) + + def before_run(self, runner): + super(CyclicMomentumUpdaterHook, self).before_run(runner) + # initiate momentum_phases + # total momentum_phases are separated as up and down + max_iter_per_phase = runner.max_iters // self.cyclic_times + iter_up_phase = int(self.step_ratio_up * max_iter_per_phase) + self.momentum_phases.append( + [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]]) + self.momentum_phases.append([ + iter_up_phase, max_iter_per_phase, max_iter_per_phase, + self.target_ratio[0], self.target_ratio[1] + ]) + + def get_momentum(self, runner, base_momentum): + curr_iter = runner.iter + for (start_iter, end_iter, max_iter_per_phase, start_ratio, + end_ratio) in self.momentum_phases: + curr_iter %= max_iter_per_phase + if start_iter <= curr_iter < end_iter: + progress = curr_iter - start_iter + return annealing_cos(base_momentum * start_ratio, + base_momentum * end_ratio, + progress / (end_iter - start_iter)) + + +@HOOKS.register_module() +class OneCycleMomentumUpdaterHook(MomentumUpdaterHook): + """OneCycle momentum Scheduler. + + This momentum scheduler usually used together with the OneCycleLrUpdater + to improve the performance. + + Args: + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is + 'max_momentum' and learning rate is 'base_lr' + Default: 0.95 + pct_start (float): The percentage of the cycle (in number of steps) + spent increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: 'cos' for cosine annealing, + 'linear' for linear annealing. + Default: 'cos' + three_phase (bool): If three_phase is True, use a third phase of the + schedule to annihilate the learning rate according to + final_div_factor instead of modifying the second phase (the first + two phases will be symmetrical about the step indicated by + pct_start). + Default: False + """ + + def __init__(self, + base_momentum=0.85, + max_momentum=0.95, + pct_start=0.3, + anneal_strategy='cos', + three_phase=False, + **kwargs): + # validate by_epoch, currently only support by_epoch=False + if 'by_epoch' not in kwargs: + kwargs['by_epoch'] = False + else: + assert not kwargs['by_epoch'], \ + 'currently only support "by_epoch" = False' + if not isinstance(base_momentum, (float, list, dict)): + raise ValueError('base_momentum must be the type among of float,' + 'list or dict.') + self._base_momentum = base_momentum + if not isinstance(max_momentum, (float, list, dict)): + raise ValueError('max_momentum must be the type among of float,' + 'list or dict.') + self._max_momentum = max_momentum + # validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError('Expected float between 0 and 1 pct_start, but ' + f'got {pct_start}') + self.pct_start = pct_start + # validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError('anneal_strategy must by one of "cos" or ' + f'"linear", instead got {anneal_strategy}') + elif anneal_strategy == 'cos': + self.anneal_func = annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = annealing_linear + self.three_phase = three_phase + self.momentum_phases = [] # init momentum_phases + super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs) + + def before_run(self, runner): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip( + optim.param_groups, _base_momentum, _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + else: + optim = runner.optimizer + if ('momentum' not in optim.defaults + and 'betas' not in optim.defaults): + raise ValueError('optimizer must support momentum with' + 'option enabled') + self.use_beta1 = 'betas' in optim.defaults + k = type(optim).__name__ + _base_momentum = format_param(k, optim, self._base_momentum) + _max_momentum = format_param(k, optim, self._max_momentum) + for group, b_momentum, m_momentum in zip(optim.param_groups, + _base_momentum, + _max_momentum): + if self.use_beta1: + _, beta2 = group['betas'] + group['betas'] = (m_momentum, beta2) + else: + group['momentum'] = m_momentum + group['base_momentum'] = b_momentum + group['max_momentum'] = m_momentum + + if self.three_phase: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 'end_iter': + float(2 * self.pct_start * runner.max_iters) - 2, + 'start_momentum': + 'base_momentum', + 'end_momentum': + 'max_momentum' + }) + self.momentum_phases.append({ + 'end_iter': runner.max_iters - 1, + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum' + }) + else: + self.momentum_phases.append({ + 'end_iter': + float(self.pct_start * runner.max_iters) - 1, + 'start_momentum': + 'max_momentum', + 'end_momentum': + 'base_momentum' + }) + self.momentum_phases.append({ + 'end_iter': runner.max_iters - 1, + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum' + }) + + def _set_momentum(self, runner, momentum_groups): + if isinstance(runner.optimizer, dict): + for k, optim in runner.optimizer.items(): + for param_group, mom in zip(optim.param_groups, + momentum_groups[k]): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + else: + for param_group, mom in zip(runner.optimizer.param_groups, + momentum_groups): + if 'momentum' in param_group.keys(): + param_group['momentum'] = mom + elif 'betas' in param_group.keys(): + param_group['betas'] = (mom, param_group['betas'][1]) + + def get_momentum(self, runner, param_group): + curr_iter = runner.iter + start_iter = 0 + for i, phase in enumerate(self.momentum_phases): + end_iter = phase['end_iter'] + if curr_iter <= end_iter or i == len(self.momentum_phases) - 1: + pct = (curr_iter - start_iter) / (end_iter - start_iter) + momentum = self.anneal_func( + param_group[phase['start_momentum']], + param_group[phase['end_momentum']], pct) + break + start_iter = end_iter + return momentum + + def get_regular_momentum(self, runner): + if isinstance(runner.optimizer, dict): + momentum_groups = {} + for k, optim in runner.optimizer.items(): + _momentum_group = [ + self.get_momentum(runner, param_group) + for param_group in optim.param_groups + ] + momentum_groups.update({k: _momentum_group}) + return momentum_groups + else: + momentum_groups = [] + for param_group in runner.optimizer.param_groups: + momentum_groups.append(self.get_momentum(runner, param_group)) + return momentum_groups diff --git a/annotator/uniformer/mmcv/runner/hooks/optimizer.py b/annotator/uniformer/mmcv/runner/hooks/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef3e9ff8f9c6926e32bdf027612267b64ed80df --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/optimizer.py @@ -0,0 +1,508 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import defaultdict +from itertools import chain + +from torch.nn.utils import clip_grad + +from annotator.uniformer.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version +from ..dist_utils import allreduce_grads +from ..fp16_utils import LossScaler, wrap_fp16_model +from .hook import HOOKS, Hook + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + from torch.cuda.amp import GradScaler +except ImportError: + pass + + +@HOOKS.register_module() +class OptimizerHook(Hook): + + def __init__(self, grad_clip=None): + self.grad_clip = grad_clip + + def clip_grads(self, params): + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return clip_grad.clip_grad_norm_(params, **self.grad_clip) + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + runner.outputs['loss'].backward() + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + + +@HOOKS.register_module() +class GradientCumulativeOptimizerHook(OptimizerHook): + """Optimizer Hook implements multi-iters gradient cumulating. + + Args: + cumulative_iters (int, optional): Num of gradient cumulative iters. + The optimizer will step every `cumulative_iters` iters. + Defaults to 1. + + Examples: + >>> # Use cumulative_iters to simulate a large batch size + >>> # It is helpful when the hardware cannot handle a large batch size. + >>> loader = DataLoader(data, batch_size=64) + >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4) + >>> # almost equals to + >>> loader = DataLoader(data, batch_size=256) + >>> optim_hook = OptimizerHook() + """ + + def __init__(self, cumulative_iters=1, **kwargs): + super(GradientCumulativeOptimizerHook, self).__init__(**kwargs) + + assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \ + f'cumulative_iters only accepts positive int, but got ' \ + f'{type(cumulative_iters)} instead.' + + self.cumulative_iters = cumulative_iters + self.divisible_iters = 0 + self.remainder_iters = 0 + self.initialized = False + + def has_batch_norm(self, module): + if isinstance(module, _BatchNorm): + return True + for m in module.children(): + if self.has_batch_norm(m): + return True + return False + + def _init(self, runner): + if runner.iter % self.cumulative_iters != 0: + runner.logger.warning( + 'Resume iter number is not divisible by cumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' + ) + + if self.has_batch_norm(runner.model) and self.cumulative_iters > 1: + runner.logger.warning( + 'GradientCumulativeOptimizerHook may slightly decrease ' + 'performance if the model has BatchNorm layers.') + + residual_iters = runner.max_iters - runner.iter + + self.divisible_iters = ( + residual_iters // self.cumulative_iters * self.cumulative_iters) + self.remainder_iters = residual_iters - self.divisible_iters + + self.initialized = True + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + runner.optimizer.zero_grad() + + +if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (using PyTorch's implementation). + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of GradScalar. + Defaults to 512. For Pytorch >= 1.6, mmcv uses official + implementation of GradScaler. If you use a dict version of + loss_scale to create GradScaler, please refer to: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + for the parameters. + + Examples: + >>> loss_scale = dict( + ... init_scale=65536.0, + ... growth_factor=2.0, + ... backoff_factor=0.5, + ... growth_interval=2000 + ... ) + >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale) + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + self._scale_update_param = None + if loss_scale == 'dynamic': + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.loss_scaler = GradScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training.""" + # wrap model mode to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer to + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients. + 3. Unscale the optimizer’s gradient tensors. + 4. Call optimizer.step() and update scale factor. + 5. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + + self.loss_scaler.scale(runner.outputs['loss']).backward() + self.loss_scaler.unscale_(runner.optimizer) + # grad clip + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using PyTorch's implementation) implements + multi-iters gradient cumulating. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + """ + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + loss = runner.outputs['loss'] + loss = loss / loss_factor + + self.loss_scaler.scale(loss).backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + self.loss_scaler.unscale_(runner.optimizer) + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters()) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() + +else: + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (mmcv's implementation). + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of LossScaler. + Defaults to 512. + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + if loss_scale == 'dynamic': + self.loss_scaler = LossScaler(mode='dynamic') + elif isinstance(loss_scale, float): + self.loss_scaler = LossScaler( + init_scale=loss_scale, mode='static') + elif isinstance(loss_scale, dict): + self.loss_scaler = LossScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. + """ + # keep a copy of fp32 weights + old_groups = runner.optimizer.param_groups + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + state = defaultdict(dict) + p_map = { + old_p: p + for old_p, p in zip( + chain(*(g['params'] for g in old_groups)), + chain(*(g['params'] + for g in runner.optimizer.param_groups))) + } + for k, v in runner.optimizer.state.items(): + state[p_map[k]] = v + runner.optimizer.state = state + # convert model to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer `loss_scalar.py` + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. Scale the gradients back and update the fp32 weight copy. + 5. Copy back the params from fp32 weight copy to the fp16 model. + 6. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + self.loss_scaler.update_scale(has_overflow) + if has_overflow: + runner.logger.warning('Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using mmcv implementation) implements multi- + iters gradient cumulating.""" + + def __init__(self, *args, **kwargs): + super(GradientCumulativeFp16OptimizerHook, + self).__init__(*args, **kwargs) + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + if runner.iter < self.divisible_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + + loss = runner.outputs['loss'] + loss = loss / loss_factor + + # scale the loss value + scaled_loss = loss * self.loss_scaler.loss_scale + scaled_loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + else: + runner.logger.warning( + 'Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + self.loss_scaler.update_scale(has_overflow) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() diff --git a/annotator/uniformer/mmcv/runner/hooks/profiler.py b/annotator/uniformer/mmcv/runner/hooks/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/profiler.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Callable, List, Optional, Union + +import torch + +from ..dist_utils import master_only +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class ProfilerHook(Hook): + """Profiler to analyze performance during training. + + PyTorch Profiler is a tool that allows the collection of the performance + metrics during the training. More details on Profiler can be found at + https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile + + Args: + by_epoch (bool): Profile performance by epoch or by iteration. + Default: True. + profile_iters (int): Number of iterations for profiling. + If ``by_epoch=True``, profile_iters indicates that they are the + first profile_iters epochs at the beginning of the + training, otherwise it indicates the first profile_iters + iterations. Default: 1. + activities (list[str]): List of activity groups (CPU, CUDA) to use in + profiling. Default: ['cpu', 'cuda']. + schedule (dict, optional): Config of generating the callable schedule. + if schedule is None, profiler will not add step markers into the + trace and table view. Default: None. + on_trace_ready (callable, dict): Either a handler or a dict of generate + handler. Default: None. + record_shapes (bool): Save information about operator's input shapes. + Default: False. + profile_memory (bool): Track tensor memory allocation/deallocation. + Default: False. + with_stack (bool): Record source information (file and line number) + for the ops. Default: False. + with_flops (bool): Use formula to estimate the FLOPS of specific + operators (matrix multiplication and 2D convolution). + Default: False. + json_trace_path (str, optional): Exports the collected trace in Chrome + JSON format. Default: None. + + Example: + >>> runner = ... # instantiate a Runner + >>> # tensorboard trace + >>> trace_config = dict(type='tb_trace', dir_name='work_dir') + >>> profiler_config = dict(on_trace_ready=trace_config) + >>> runner.register_profiler_hook(profiler_config) + >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)]) + """ + + def __init__(self, + by_epoch: bool = True, + profile_iters: int = 1, + activities: List[str] = ['cpu', 'cuda'], + schedule: Optional[dict] = None, + on_trace_ready: Optional[Union[Callable, dict]] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + json_trace_path: Optional[str] = None) -> None: + try: + from torch import profiler # torch version >= 1.8.1 + except ImportError: + raise ImportError('profiler is the new feature of torch1.8.1, ' + f'but your version is {torch.__version__}') + + assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.' + self.by_epoch = by_epoch + + if profile_iters < 1: + raise ValueError('profile_iters should be greater than 0, but got ' + f'{profile_iters}') + self.profile_iters = profile_iters + + if not isinstance(activities, list): + raise ValueError( + f'activities should be list, but got {type(activities)}') + self.activities = [] + for activity in activities: + activity = activity.lower() + if activity == 'cpu': + self.activities.append(profiler.ProfilerActivity.CPU) + elif activity == 'cuda': + self.activities.append(profiler.ProfilerActivity.CUDA) + else: + raise ValueError( + f'activity should be "cpu" or "cuda", but got {activity}') + + if schedule is not None: + self.schedule = profiler.schedule(**schedule) + else: + self.schedule = None + + self.on_trace_ready = on_trace_ready + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_flops = with_flops + self.json_trace_path = json_trace_path + + @master_only + def before_run(self, runner): + if self.by_epoch and runner.max_epochs < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_epochs}') + + if not self.by_epoch and runner.max_iters < self.profile_iters: + raise ValueError('self.profile_iters should not be greater than ' + f'{runner.max_iters}') + + if callable(self.on_trace_ready): # handler + _on_trace_ready = self.on_trace_ready + elif isinstance(self.on_trace_ready, dict): # config of handler + trace_cfg = self.on_trace_ready.copy() + trace_type = trace_cfg.pop('type') # log_trace handler + if trace_type == 'log_trace': + + def _log_handler(prof): + print(prof.key_averages().table(**trace_cfg)) + + _on_trace_ready = _log_handler + elif trace_type == 'tb_trace': # tensorboard_trace handler + try: + import torch_tb_profiler # noqa: F401 + except ImportError: + raise ImportError('please run "pip install ' + 'torch-tb-profiler" to install ' + 'torch_tb_profiler') + _on_trace_ready = torch.profiler.tensorboard_trace_handler( + **trace_cfg) + else: + raise ValueError('trace_type should be "log_trace" or ' + f'"tb_trace", but got {trace_type}') + elif self.on_trace_ready is None: + _on_trace_ready = None # type: ignore + else: + raise ValueError('on_trace_ready should be handler, dict or None, ' + f'but got {type(self.on_trace_ready)}') + + if runner.max_epochs > 1: + warnings.warn(f'profiler will profile {runner.max_epochs} epochs ' + 'instead of 1 epoch. Since profiler will slow down ' + 'the training, it is recommended to train 1 epoch ' + 'with ProfilerHook and adjust your setting according' + ' to the profiler summary. During normal training ' + '(epoch > 1), you may disable the ProfilerHook.') + + self.profiler = torch.profiler.profile( + activities=self.activities, + schedule=self.schedule, + on_trace_ready=_on_trace_ready, + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_flops=self.with_flops) + + self.profiler.__enter__() + runner.logger.info('profiler is profiling...') + + @master_only + def after_train_epoch(self, runner): + if self.by_epoch and runner.epoch == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) + + @master_only + def after_train_iter(self, runner): + self.profiler.step() + if not self.by_epoch and runner.iter == self.profile_iters - 1: + runner.logger.info('profiler may take a few minutes...') + self.profiler.__exit__(None, None, None) + if self.json_trace_path is not None: + self.profiler.export_chrome_trace(self.json_trace_path) diff --git a/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class DistSamplerSeedHook(Hook): + """Data-loading sampler for distributed training. + + When distributed training, it is only useful in conjunction with + :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same + purpose with :obj:`IterLoader`. + """ + + def before_epoch(self, runner): + if hasattr(runner.data_loader.sampler, 'set_epoch'): + # in case the data loader uses `SequentialSampler` in Pytorch + runner.data_loader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + # batch sampler in pytorch warps the sampler as its attributes. + runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..dist_utils import allreduce_params +from .hook import HOOKS, Hook + + +@HOOKS.register_module() +class SyncBuffersHook(Hook): + """Synchronize model buffers such as running_mean and running_var in BN at + the end of each epoch. + + Args: + distributed (bool): Whether distributed training is used. It is + effective only for distributed training. Defaults to True. + """ + + def __init__(self, distributed=True): + self.distributed = distributed + + def after_epoch(self, runner): + """All-reduce model buffers at the end of each epoch.""" + if self.distributed: + allreduce_params(runner.model.buffers()) diff --git a/annotator/uniformer/mmcv/runner/iter_based_runner.py b/annotator/uniformer/mmcv/runner/iter_based_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1df4de8c0285669dec9b014dfd1f3dd1600f0831 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/iter_based_runner.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import platform +import shutil +import time +import warnings + +import torch +from torch.optim import Optimizer + +import annotator.uniformer.mmcv as mmcv +from .base_runner import BaseRunner +from .builder import RUNNERS +from .checkpoint import save_checkpoint +from .hooks import IterTimerHook +from .utils import get_host_info + + +class IterLoader: + + def __init__(self, dataloader): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._epoch = 0 + + @property + def epoch(self): + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, 'set_epoch'): + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __len__(self): + return len(self._dataloader) + + +@RUNNERS.register_module() +class IterBasedRunner(BaseRunner): + """Iteration-based Runner. + + This runner train models iteration by iteration. + """ + + def train(self, data_loader, **kwargs): + self.model.train() + self.mode = 'train' + self.data_loader = data_loader + self._epoch = data_loader.epoch + data_batch = next(data_loader) + self.call_hook('before_train_iter') + outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.train_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_train_iter') + self._inner_iter += 1 + self._iter += 1 + + @torch.no_grad() + def val(self, data_loader, **kwargs): + self.model.eval() + self.mode = 'val' + self.data_loader = data_loader + data_batch = next(data_loader) + self.call_hook('before_val_iter') + outputs = self.model.val_step(data_batch, **kwargs) + if not isinstance(outputs, dict): + raise TypeError('model.val_step() must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs + self.call_hook('after_val_iter') + self._inner_iter += 1 + + def run(self, data_loaders, workflow, max_iters=None, **kwargs): + """Start running. + + Args: + data_loaders (list[:obj:`DataLoader`]): Dataloaders for training + and validation. + workflow (list[tuple]): A list of (phase, iters) to specify the + running order and iterations. E.g, [('train', 10000), + ('val', 1000)] means running 10000 iterations for training and + 1000 iterations for validation, iteratively. + """ + assert isinstance(data_loaders, list) + assert mmcv.is_list_of(workflow, tuple) + assert len(data_loaders) == len(workflow) + if max_iters is not None: + warnings.warn( + 'setting max_iters in run is deprecated, ' + 'please set max_iters in runner_config', DeprecationWarning) + self._max_iters = max_iters + assert self._max_iters is not None, ( + 'max_iters must be specified during instantiation') + + work_dir = self.work_dir if self.work_dir is not None else 'NONE' + self.logger.info('Start running, host: %s, work_dir: %s', + get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) + self.logger.info('workflow: %s, max: %d iters', workflow, + self._max_iters) + self.call_hook('before_run') + + iter_loaders = [IterLoader(x) for x in data_loaders] + + self.call_hook('before_epoch') + + while self.iter < self._max_iters: + for i, flow in enumerate(workflow): + self._inner_iter = 0 + mode, iters = flow + if not isinstance(mode, str) or not hasattr(self, mode): + raise ValueError( + 'runner has no method named "{}" to run a workflow'. + format(mode)) + iter_runner = getattr(self, mode) + for _ in range(iters): + if mode == 'train' and self.iter >= self._max_iters: + break + iter_runner(iter_loaders[i], **kwargs) + + time.sleep(1) # wait for some hooks like loggers to finish + self.call_hook('after_epoch') + self.call_hook('after_run') + + def resume(self, + checkpoint, + resume_optimizer=True, + map_location='default'): + """Resume model from checkpoint. + + Args: + checkpoint (str): Checkpoint to resume from. + resume_optimizer (bool, optional): Whether resume the optimizer(s) + if the checkpoint file includes optimizer(s). Default to True. + map_location (str, optional): Same as :func:`torch.load`. + Default to 'default'. + """ + if map_location == 'default': + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + self._inner_iter = checkpoint['meta']['iter'] + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') + + def save_checkpoint(self, + out_dir, + filename_tmpl='iter_{}.pth', + meta=None, + save_optimizer=True, + create_symlink=True): + """Save checkpoint to file. + + Args: + out_dir (str): Directory to save checkpoint files. + filename_tmpl (str, optional): Checkpoint file template. + Defaults to 'iter_{}.pth'. + meta (dict, optional): Metadata to be saved in checkpoint. + Defaults to None. + save_optimizer (bool, optional): Whether save optimizer. + Defaults to True. + create_symlink (bool, optional): Whether create symlink to the + latest checkpoint file. Defaults to True. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + # Note: meta.update(self.meta) should be done before + # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise + # there will be problems with resumed checkpoints. + # More details in https://github.com/open-mmlab/mmcv/pull/1108 + meta.update(epoch=self.epoch + 1, iter=self.iter) + + filename = filename_tmpl.format(self.iter + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) + + def register_training_hooks(self, + lr_config, + optimizer_config=None, + checkpoint_config=None, + log_config=None, + momentum_config=None, + custom_hooks_config=None): + """Register default hooks for iter-based training. + + Checkpoint hook, optimizer stepper hook and logger hooks will be set to + `by_epoch=False` by default. + + Default hooks include: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. + """ + if checkpoint_config is not None: + checkpoint_config.setdefault('by_epoch', False) + if lr_config is not None: + lr_config.setdefault('by_epoch', False) + if log_config is not None: + for info in log_config['hooks']: + info.setdefault('by_epoch', False) + super(IterBasedRunner, self).register_training_hooks( + lr_config=lr_config, + momentum_config=momentum_config, + optimizer_config=optimizer_config, + checkpoint_config=checkpoint_config, + log_config=log_config, + timer_config=IterTimerHook(), + custom_hooks_config=custom_hooks_config) diff --git a/annotator/uniformer/mmcv/runner/log_buffer.py b/annotator/uniformer/mmcv/runner/log_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/log_buffer.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import numpy as np + + +class LogBuffer: + + def __init__(self): + self.val_history = OrderedDict() + self.n_history = OrderedDict() + self.output = OrderedDict() + self.ready = False + + def clear(self): + self.val_history.clear() + self.n_history.clear() + self.clear_output() + + def clear_output(self): + self.output.clear() + self.ready = False + + def update(self, vars, count=1): + assert isinstance(vars, dict) + for key, var in vars.items(): + if key not in self.val_history: + self.val_history[key] = [] + self.n_history[key] = [] + self.val_history[key].append(var) + self.n_history[key].append(count) + + def average(self, n=0): + """Average latest n values or all values.""" + assert n >= 0 + for key in self.val_history: + values = np.array(self.val_history[key][-n:]) + nums = np.array(self.n_history[key][-n:]) + avg = np.sum(values * nums) / np.sum(nums) + self.output[key] = avg + self.ready = True diff --git a/annotator/uniformer/mmcv/runner/optimizer/__init__.py b/annotator/uniformer/mmcv/runner/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer, + build_optimizer_constructor) +from .default_constructor import DefaultOptimizerConstructor + +__all__ = [ + 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', + 'build_optimizer', 'build_optimizer_constructor' +] diff --git a/annotator/uniformer/mmcv/runner/optimizer/builder.py b/annotator/uniformer/mmcv/runner/optimizer/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/builder.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect + +import torch + +from ...utils import Registry, build_from_cfg + +OPTIMIZERS = Registry('optimizer') +OPTIMIZER_BUILDERS = Registry('optimizer builder') + + +def register_torch_optimizers(): + torch_optimizers = [] + for module_name in dir(torch.optim): + if module_name.startswith('__'): + continue + _optim = getattr(torch.optim, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module()(_optim) + torch_optimizers.append(module_name) + return torch_optimizers + + +TORCH_OPTIMIZERS = register_torch_optimizers() + + +def build_optimizer_constructor(cfg): + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + + +def build_optimizer(model, cfg): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0da3503b75441738efe38d70352b55a210a34a --- /dev/null +++ b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py @@ -0,0 +1,249 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +from torch.nn import GroupNorm, LayerNorm + +from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of +from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist +from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS + + +@OPTIMIZER_BUILDERS.register_module() +class DefaultOptimizerConstructor: + """Default constructor for optimizers. + + By default each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain the following fields: + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Default: False. + + Note: + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset + layer. So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the + offset layer in deformable convs, set ``dcn_offset_lr_mult`` + to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when + the model contains multiple DCN layers in places other than + backbone. + + Args: + model (:obj:`nn.Module`): The model with parameters to be optimized. + optimizer_cfg (dict): The config dict of the optimizer. + Positional fields are + + - `type`: class name of the optimizer. + + Optional fields are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, + >>> weight_decay=0.0001) + >>> paramwise_cfg = dict(norm_decay_mult=0.) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95) + >>> paramwise_cfg = dict(custom_keys={ + '.backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_builder = DefaultOptimizerConstructor( + >>> optimizer_cfg, paramwise_cfg) + >>> optimizer = optim_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def __init__(self, optimizer_cfg, paramwise_cfg=None): + if not isinstance(optimizer_cfg, dict): + raise TypeError('optimizer_cfg should be a dict', + f'but got {type(optimizer_cfg)}') + self.optimizer_cfg = optimizer_cfg + self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg + self.base_lr = optimizer_cfg.get('lr', None) + self.base_wd = optimizer_cfg.get('weight_decay', None) + self._validate_cfg() + + def _validate_cfg(self): + if not isinstance(self.paramwise_cfg, dict): + raise TypeError('paramwise_cfg should be None or a dict, ' + f'but got {type(self.paramwise_cfg)}') + + if 'custom_keys' in self.paramwise_cfg: + if not isinstance(self.paramwise_cfg['custom_keys'], dict): + raise TypeError( + 'If specified, custom_keys must be a dict, ' + f'but got {type(self.paramwise_cfg["custom_keys"])}') + if self.base_wd is None: + for key in self.paramwise_cfg['custom_keys']: + if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]: + raise ValueError('base_wd should not be None') + + # get base lr and weight decay + # weight_decay must be explicitly specified if mult is specified + if ('bias_decay_mult' in self.paramwise_cfg + or 'norm_decay_mult' in self.paramwise_cfg + or 'dwconv_decay_mult' in self.paramwise_cfg): + if self.base_wd is None: + raise ValueError('base_wd should not be None') + + def _is_in(self, param_group, param_group_list): + assert is_list_of(param_group_list, dict) + param = set(param_group['params']) + param_set = set() + for group in param_group_list: + param_set.update(set(group['params'])) + + return not param.isdisjoint(param_set) + + def add_params(self, params, module, prefix='', is_dcn_module=None): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + warnings.warn(f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}') + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + break + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not (is_norm or is_dcn_module): + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # depth-wise conv + elif is_dwconv: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # bias lr and decay + elif name == 'bias' and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + params.append(param_group) + + if check_ops_exist(): + from annotator.uniformer.mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) + + def __call__(self, model): + if hasattr(model, 'module'): + model = model.module + + optimizer_cfg = self.optimizer_cfg.copy() + # if no paramwise option is specified, just use the global setting + if not self.paramwise_cfg: + optimizer_cfg['params'] = model.parameters() + return build_from_cfg(optimizer_cfg, OPTIMIZERS) + + # set param-wise lr and weight decay recursively + params = [] + self.add_params(params, model) + optimizer_cfg['params'] = params + + return build_from_cfg(optimizer_cfg, OPTIMIZERS) diff --git a/annotator/uniformer/mmcv/runner/priority.py b/annotator/uniformer/mmcv/runner/priority.py new file mode 100644 index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67 --- /dev/null +++ b/annotator/uniformer/mmcv/runner/priority.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from enum import Enum + + +class Priority(Enum): + """Hook priority levels. + + +--------------+------------+ + | Level | Value | + +==============+============+ + | HIGHEST | 0 | + +--------------+------------+ + | VERY_HIGH | 10 | + +--------------+------------+ + | HIGH | 30 | + +--------------+------------+ + | ABOVE_NORMAL | 40 | + +--------------+------------+ + | NORMAL | 50 | + +--------------+------------+ + | BELOW_NORMAL | 60 | + +--------------+------------+ + | LOW | 70 | + +--------------+------------+ + | VERY_LOW | 90 | + +--------------+------------+ + | LOWEST | 100 | + +--------------+------------+ + """ + + HIGHEST = 0 + VERY_HIGH = 10 + HIGH = 30 + ABOVE_NORMAL = 40 + NORMAL = 50 + BELOW_NORMAL = 60 + LOW = 70 + VERY_LOW = 90 + LOWEST = 100 + + +def get_priority(priority): + """Get priority value. + + Args: + priority (int or str or :obj:`Priority`): Priority. + + Returns: + int: The priority value. + """ + if isinstance(priority, int): + if priority < 0 or priority > 100: + raise ValueError('priority must be between 0 and 100') + return priority + elif isinstance(priority, Priority): + return priority.value + elif isinstance(priority, str): + return Priority[priority.upper()].value + else: + raise TypeError('priority must be an integer or Priority enum value') diff --git a/annotator/uniformer/mmcv/runner/utils.py b/annotator/uniformer/mmcv/runner/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5befb8e56ece50b5fecfd007b26f8a29124c0bd --- /dev/null +++ b/annotator/uniformer/mmcv/runner/utils.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import random +import sys +import time +import warnings +from getpass import getuser +from socket import gethostname + +import numpy as np +import torch + +import annotator.uniformer.mmcv as mmcv + + +def get_host_info(): + """Get hostname and username. + + Return empty string if exception raised, e.g. ``getpass.getuser()`` will + lead to error in docker container + """ + host = '' + try: + host = f'{getuser()}@{gethostname()}' + except Exception as e: + warnings.warn(f'Host or user not found: {str(e)}') + finally: + return host + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def obj_from_dict(info, parent=None, default_args=None): + """Initialize an object from dict. + + The dict must contain the key "type", which indicates the object type, it + can be either a string or type, such as "list" or ``list``. Remaining + fields are treated as the arguments for constructing the object. + + Args: + info (dict): Object types and arguments. + parent (:class:`module`): Module which may containing expected object + classes. + default_args (dict, optional): Default arguments for initializing the + object. + + Returns: + any type: Object built from the dict. + """ + assert isinstance(info, dict) and 'type' in info + assert isinstance(default_args, dict) or default_args is None + args = info.copy() + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + if parent is not None: + obj_type = getattr(parent, obj_type) + else: + obj_type = sys.modules[obj_type] + elif not isinstance(obj_type, type): + raise TypeError('type must be a str or valid type, but ' + f'got {type(obj_type)}') + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_type(**args) + + +def set_random_seed(seed, deterministic=False, use_rank_shift=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + rank_shift (bool): Whether to add rank number to the random seed to + have different random seed in different threads. Default: False. + """ + if use_rank_shift: + rank, _ = mmcv.runner.get_dist_info() + seed += rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/annotator/uniformer/mmcv/utils/__init__.py b/annotator/uniformer/mmcv/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/__init__.py @@ -0,0 +1,69 @@ +# flake8: noqa +# Copyright (c) OpenMMLab. All rights reserved. +from .config import Config, ConfigDict, DictAction +from .misc import (check_prerequisites, concat_list, deprecated_api_warning, + has_method, import_modules_from_strings, is_list_of, + is_method_overridden, is_seq_of, is_str, is_tuple_of, + iter_cast, list_cast, requires_executable, requires_package, + slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, + to_ntuple, tuple_cast) +from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, + scandir, symlink) +from .progressbar import (ProgressBar, track_iter_progress, + track_parallel_progress, track_progress) +from .testing import (assert_attrs_equal, assert_dict_contains_subset, + assert_dict_has_keys, assert_is_norm_layer, + assert_keys_equal, assert_params_all_zeros, + check_python_script) +from .timer import Timer, TimerError, check_time +from .version_utils import digit_version, get_git_hash + +try: + import torch +except ImportError: + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast', + 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', + 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package', + 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist', + 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', + 'track_progress', 'track_iter_progress', 'track_parallel_progress', + 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', + 'digit_version', 'get_git_hash', 'import_modules_from_strings', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', + 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', + 'is_method_overridden', 'has_method' + ] +else: + from .env import collect_env + from .logging import get_logger, print_log + from .parrots_jit import jit, skip_no_elena + from .parrots_wrapper import ( + TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, + PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, + _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, + _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home) + from .registry import Registry, build_from_cfg + from .trace import is_jit_tracing + __all__ = [ + 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', + 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', + 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', + 'check_prerequisites', 'requires_package', 'requires_executable', + 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', + 'symlink', 'scandir', 'ProgressBar', 'track_progress', + 'track_iter_progress', 'track_parallel_progress', 'Registry', + 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm', + '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm', + '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd', + 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', + 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', + 'deprecated_api_warning', 'digit_version', 'get_git_hash', + 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', + 'assert_params_all_zeros', 'check_python_script', + 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', + '_get_cuda_home', 'has_method' + ] diff --git a/annotator/uniformer/mmcv/utils/config.py b/annotator/uniformer/mmcv/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..17149353aefac6d737c67bb2f35a3a6cd2147b0a --- /dev/null +++ b/annotator/uniformer/mmcv/utils/config.py @@ -0,0 +1,688 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import uuid +import warnings +from argparse import Action, ArgumentParser +from collections import abc +from importlib import import_module + +from addict import Dict +from yapf.yapflib.yapf_api import FormatCode + +from .misc import import_modules_from_strings +from .path import check_file_exist + +if platform.system() == 'Windows': + import regex as re +else: + import re + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = ['filename', 'text', 'pretty_text'] + + +class ConfigDict(Dict): + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +def add_args(parser, cfg, prefix=''): + for k, v in cfg.items(): + if isinstance(v, str): + parser.add_argument('--' + prefix + k) + elif isinstance(v, int): + parser.add_argument('--' + prefix + k, type=int) + elif isinstance(v, float): + parser.add_argument('--' + prefix + k, type=float) + elif isinstance(v, bool): + parser.add_argument('--' + prefix + k, action='store_true') + elif isinstance(v, dict): + add_args(parser, v, prefix + k + '.') + elif isinstance(v, abc.Iterable): + parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') + else: + print(f'cannot parse key {prefix + k} of type {type(v)}') + return parser + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. + + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ + + @staticmethod + def _validate_py_syntax(filename): + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + @staticmethod + def _substitute_predefined_vars(filename, temp_config_name): + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' + base_var_dict[randstr] = base_var + regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split('.'): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars( + v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split('.'): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _file2dict(filename, use_predefined_variables=True): + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname) + if platform.system() == 'Windows': + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, + temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) + + if filename.endswith('.py'): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith('__') + } + # delete imported module + del sys.modules[temp_module_name] + elif filename.endswith(('.yml', '.yaml', '.json')): + import annotator.uniformer.mmcv as mmcv + cfg_dict = mmcv.load(temp_config_file.name) + # close temp file + temp_config_file.close() + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f'The config file {filename} will be deprecated ' \ + 'in the future.' + if 'expected' in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' \ + 'instead.' + if 'reference' in deprecation_info: + warning_msg += ' More information can be found at ' \ + f'{deprecation_info["reference"]}' + warnings.warn(warning_msg) + + cfg_text = filename + '\n' + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance( + base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError('Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = '\n'.join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f'Index {k} exceeds the length of list {b}') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, + dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f'{k}={v} in child config cannot inherit from base ' + f'because {k} is a dict in the child config but is of ' + f'type {type(b[k])} in base config. You may set ' + f'`{DELETE_KEY}=True` to ignore the base config') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = v + return b + + @staticmethod + def fromfile(filename, + use_predefined_variables=True, + import_custom_modules=True): + cfg_dict, cfg_text = Config._file2dict(filename, + use_predefined_variables) + if import_custom_modules and cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def fromstring(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + warnings.warn( + 'Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix=file_format, + delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.fromfile(temp_file.name) + os.remove(temp_file.name) + return cfg + + @staticmethod + def auto_argparser(description=None): + """Generate argparser from config file automatically (experimental)""" + partial_parser = ArgumentParser(description=description) + partial_parser.add_argument('config', help='config file path') + cfg_file = partial_parser.parse_known_args()[0].config + cfg = Config.fromfile(cfg_file) + parser = ArgumentParser(description=description) + parser.add_argument('config', help='config file path') + add_args(parser, cfg) + return parser, cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + super(Config, self).__setattr__('_filename', filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, 'r') as f: + text = f.read() + else: + text = '' + super(Config, self).__setattr__('_text', text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = '[\n' + v_str += '\n'.join( + f'dict({_indent(_format_dict(v_), indent)}),' + for v_ in v).rstrip(',') + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + ']' + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) + + def dump(self, file=None): + cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() + if self.filename.endswith('.py'): + if file is None: + return self.pretty_text + else: + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + import annotator.uniformer.mmcv as mmcv + if file is None: + file_format = self.filename.split('.')[-1] + return mmcv.dump(cfg_dict, file_format=file_format) + else: + mmcv.dump(cfg_dict, file) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(depth=50, with_cp=True))) + + # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) + >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Default: True. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split('.') + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + super(Config, self).__setattr__( + '_cfg_dict', + Config._merge_a_into_b( + option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + + +class DictAction(Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. List options can + be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit + brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build + list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """ + + @staticmethod + def _parse_int_float_bool(val): + try: + return int(val) + except ValueError: + pass + try: + return float(val) + except ValueError: + pass + if val.lower() in ['true', 'false']: + return True if val.lower() == 'true' else False + return val + + @staticmethod + def _parse_iterable(val): + """Parse iterable values in the string. + + All elements inside '()' or '[]' are treated as iterable values. + + Args: + val (str): Value string. + + Returns: + list | tuple: The expanded list or tuple from the string. + + Examples: + >>> DictAction._parse_iterable('1,2,3') + [1, 2, 3] + >>> DictAction._parse_iterable('[a, b, c]') + ['a', 'b', 'c'] + >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') + [(1, 2, 3), ['a', 'b'], 'c'] + """ + + def find_next_comma(string): + """Find the position of next comma in the string. + + If no ',' is found in the string, return the string length. All + chars inside '()' and '[]' are treated as one element and thus ',' + inside these brackets are ignored. + """ + assert (string.count('(') == string.count(')')) and ( + string.count('[') == string.count(']')), \ + f'Imbalanced brackets exist in {string}' + end = len(string) + for idx, char in enumerate(string): + pre = string[:idx] + # The string before this ',' is balanced + if ((char == ',') and (pre.count('(') == pre.count(')')) + and (pre.count('[') == pre.count(']'))): + end = idx + break + return end + + # Strip ' and " characters and replace whitespace. + val = val.strip('\'\"').replace(' ', '') + is_tuple = False + if val.startswith('(') and val.endswith(')'): + is_tuple = True + val = val[1:-1] + elif val.startswith('[') and val.endswith(']'): + val = val[1:-1] + elif ',' not in val: + # val is a single value + return DictAction._parse_int_float_bool(val) + + values = [] + while len(val) > 0: + comma_idx = find_next_comma(val) + element = DictAction._parse_iterable(val[:comma_idx]) + values.append(element) + val = val[comma_idx + 1:] + if is_tuple: + values = tuple(values) + return values + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for kv in values: + key, val = kv.split('=', maxsplit=1) + options[key] = self._parse_iterable(val) + setattr(namespace, self.dest, options) diff --git a/annotator/uniformer/mmcv/utils/env.py b/annotator/uniformer/mmcv/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f0d92529e193e6d8339419bcd9bed7901a7769 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/env.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file holding some environment constant for sharing by other files.""" + +import os.path as osp +import subprocess +import sys +from collections import defaultdict + +import cv2 +import torch + +import annotator.uniformer.mmcv as mmcv +from .parrots_wrapper import get_build_config + + +def collect_env(): + """Collect the information of the running environments. + + Returns: + dict: The environment information. The following fields are contained. + + - sys.platform: The variable of ``sys.platform``. + - Python: Python version. + - CUDA available: Bool, indicating if CUDA is available. + - GPU devices: Device type of each GPU. + - CUDA_HOME (optional): The env var ``CUDA_HOME``. + - NVCC (optional): NVCC version. + - GCC: GCC version, "n/a" if GCC is not installed. + - PyTorch: PyTorch version. + - PyTorch compiling details: The output of \ + ``torch.__config__.show()``. + - TorchVision (optional): TorchVision version. + - OpenCV: OpenCV version. + - MMCV: MMCV version. + - MMCV Compiler: The GCC version for compiling MMCV ops. + - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops. + """ + env_info = {} + env_info['sys.platform'] = sys.platform + env_info['Python'] = sys.version.replace('\n', '') + + cuda_available = torch.cuda.is_available() + env_info['CUDA available'] = cuda_available + + if cuda_available: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, device_ids in devices.items(): + env_info['GPU ' + ','.join(device_ids)] = name + + from annotator.uniformer.mmcv.utils.parrots_wrapper import _get_cuda_home + CUDA_HOME = _get_cuda_home() + env_info['CUDA_HOME'] = CUDA_HOME + + if CUDA_HOME is not None and osp.isdir(CUDA_HOME): + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output( + f'"{nvcc}" -V | tail -n1', shell=True) + nvcc = nvcc.decode('utf-8').strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + env_info['NVCC'] = nvcc + + try: + gcc = subprocess.check_output('gcc --version | head -n1', shell=True) + gcc = gcc.decode('utf-8').strip() + env_info['GCC'] = gcc + except subprocess.CalledProcessError: # gcc is unavailable + env_info['GCC'] = 'n/a' + + env_info['PyTorch'] = torch.__version__ + env_info['PyTorch compiling details'] = get_build_config() + + try: + import torchvision + env_info['TorchVision'] = torchvision.__version__ + except ModuleNotFoundError: + pass + + env_info['OpenCV'] = cv2.__version__ + + env_info['MMCV'] = mmcv.__version__ + + try: + from annotator.uniformer.mmcv.ops import get_compiler_version, get_compiling_cuda_version + except ModuleNotFoundError: + env_info['MMCV Compiler'] = 'n/a' + env_info['MMCV CUDA Compiler'] = 'n/a' + else: + env_info['MMCV Compiler'] = get_compiler_version() + env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version() + + return env_info diff --git a/annotator/uniformer/mmcv/utils/ext_loader.py b/annotator/uniformer/mmcv/utils/ext_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/ext_loader.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os +import pkgutil +import warnings +from collections import namedtuple + +import torch + +if torch.__version__ != 'parrots': + + def load_ext(name, funcs): + ext = importlib.import_module('mmcv.' + name) + for fun in funcs: + assert hasattr(ext, fun), f'{fun} miss in module {name}' + return ext +else: + from parrots import extension + from parrots.base import ParrotsException + + has_return_value_ops = [ + 'nms', + 'softnms', + 'nms_match', + 'nms_rotated', + 'top_pool_forward', + 'top_pool_backward', + 'bottom_pool_forward', + 'bottom_pool_backward', + 'left_pool_forward', + 'left_pool_backward', + 'right_pool_forward', + 'right_pool_backward', + 'fused_bias_leakyrelu', + 'upfirdn2d', + 'ms_deform_attn_forward', + 'pixel_group', + 'contour_expand', + ] + + def get_fake_func(name, e): + + def fake_func(*args, **kwargs): + warnings.warn(f'{name} is not supported in parrots now') + raise e + + return fake_func + + def load_ext(name, funcs): + ExtModule = namedtuple('ExtModule', funcs) + ext_list = [] + lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + for fun in funcs: + try: + ext_fun = extension.load(fun, name, lib_dir=lib_root) + except ParrotsException as e: + if 'No element registered' not in e.message: + warnings.warn(e.message) + ext_fun = get_fake_func(fun, e) + ext_list.append(ext_fun) + else: + if fun in has_return_value_ops: + ext_list.append(ext_fun.op) + else: + ext_list.append(ext_fun.op_) + return ExtModule(*ext_list) + + +def check_ops_exist(): + ext_loader = pkgutil.find_loader('mmcv._ext') + return ext_loader is not None diff --git a/annotator/uniformer/mmcv/utils/logging.py b/annotator/uniformer/mmcv/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/logging.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +import torch.distributed as dist + +logger_initialized = {} + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == 'silent': + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + 'logger should be either a logging.Logger object, str, ' + f'"silent" or None, but got {type(logger)}') diff --git a/annotator/uniformer/mmcv/utils/misc.py b/annotator/uniformer/mmcv/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c --- /dev/null +++ b/annotator/uniformer/mmcv/utils/misc.py @@ -0,0 +1,377 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections.abc +import functools +import itertools +import subprocess +import warnings +from collections import abc +from importlib import import_module +from inspect import getfullargspec +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def is_str(x): + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError( + f'custom_imports must be a list but got type {type(imports)}') + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError( + f'{imp} is of type {type(imp)} and cannot be imported.') + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f'{imp} failed to import and is ignored.', + UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def iter_cast(inputs, dst_type, return_type=None): + """Cast elements of an iterable object into some type. + + Args: + inputs (Iterable): The input object. + dst_type (type): Destination type. + return_type (type, optional): If specified, the output object will be + converted to this type, otherwise an iterator. + + Returns: + iterator or specified type: The converted object. + """ + if not isinstance(inputs, abc.Iterable): + raise TypeError('inputs must be an iterable object') + if not isinstance(dst_type, type): + raise TypeError('"dst_type" must be a valid type') + + out_iterable = map(dst_type, inputs) + + if return_type is None: + return out_iterable + else: + return return_type(out_iterable) + + +def list_cast(inputs, dst_type): + """Cast elements of an iterable object into a list of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=list) + + +def tuple_cast(inputs, dst_type): + """Cast elements of an iterable object into a tuple of some type. + + A partial method of :func:`iter_cast`. + """ + return iter_cast(inputs, dst_type, return_type=tuple) + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_list_of(seq, expected_type): + """Check whether it is a list of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=list) + + +def is_tuple_of(seq, expected_type): + """Check whether it is a tuple of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=tuple) + + +def slice_list(in_list, lens): + """Slice a list into several sub lists by a list of given length. + + Args: + in_list (list): The list to be sliced. + lens(int or list): The expected length of each out list. + + Returns: + list: A list of sliced list. + """ + if isinstance(lens, int): + assert len(in_list) % lens == 0 + lens = [lens] * int(len(in_list) / lens) + if not isinstance(lens, list): + raise TypeError('"indices" must be an integer or a list of integers') + elif sum(lens) != len(in_list): + raise ValueError('sum of lens and list length does not ' + f'match: {sum(lens)} != {len(in_list)}') + out_list = [] + idx = 0 + for i in range(len(lens)): + out_list.append(in_list[idx:idx + lens[i]]) + idx += lens[i] + return out_list + + +def concat_list(in_list): + """Concatenate a list of list into a single list. + + Args: + in_list (list): The list of list to be merged. + + Returns: + list: The concatenated flat list. + """ + return list(itertools.chain(*in_list)) + + +def check_prerequisites( + prerequisites, + checker, + msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' + 'found, please install them first.'): # yapf: disable + """A decorator factory to check if prerequisites are satisfied. + + Args: + prerequisites (str of list[str]): Prerequisites to be checked. + checker (callable): The checker method that returns True if a + prerequisite is meet, False otherwise. + msg_tmpl (str): The message template with two variables. + + Returns: + decorator: A specific decorator. + """ + + def wrap(func): + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + requirements = [prerequisites] if isinstance( + prerequisites, str) else prerequisites + missing = [] + for item in requirements: + if not checker(item): + missing.append(item) + if missing: + print(msg_tmpl.format(', '.join(missing), func.__name__)) + raise RuntimeError('Prerequisites not meet.') + else: + return func(*args, **kwargs) + + return wrapped_func + + return wrap + + +def _check_py_package(package): + try: + import_module(package) + except ImportError: + return False + else: + return True + + +def _check_executable(cmd): + if subprocess.call(f'which {cmd}', shell=True) != 0: + return False + else: + return True + + +def requires_package(prerequisites): + """A decorator to check if some python packages are installed. + + Example: + >>> @requires_package('numpy') + >>> func(arg1, args): + >>> return numpy.zeros(1) + array([0.]) + >>> @requires_package(['numpy', 'non_package']) + >>> func(arg1, args): + >>> return numpy.zeros(1) + ImportError + """ + return check_prerequisites(prerequisites, checker=_check_py_package) + + +def requires_executable(prerequisites): + """A decorator to check if some executable files are installed. + + Example: + >>> @requires_executable('ffmpeg') + >>> func(arg1, args): + >>> print(1) + 1 + """ + return check_prerequisites(prerequisites, checker=_check_executable) + + +def deprecated_api_warning(name_dict, cls_name=None): + """A decorator to check if some arguments are deprecate and try to replace + deprecate src_arg_name to dst_arg_name. + + Args: + name_dict(dict): + key (str): Deprecate argument names. + val (str): Expected argument names. + + Returns: + func: New function. + """ + + def api_warning_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get name of the function + func_name = old_func.__name__ + if cls_name is not None: + func_name = f'{cls_name}.{func_name}' + if args: + arg_names = args_info.args[:len(args)] + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in arg_names: + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead') + arg_names[arg_names.index(src_arg_name)] = dst_arg_name + if kwargs: + for src_arg_name, dst_arg_name in name_dict.items(): + if src_arg_name in kwargs: + + assert dst_arg_name not in kwargs, ( + f'The expected behavior is to replace ' + f'the deprecated key `{src_arg_name}` to ' + f'new key `{dst_arg_name}`, but got them ' + f'in the arguments at the same time, which ' + f'is confusing. `{src_arg_name} will be ' + f'deprecated in the future, please ' + f'use `{dst_arg_name}` instead.') + + warnings.warn( + f'"{src_arg_name}" is deprecated in ' + f'`{func_name}`, please use "{dst_arg_name}" ' + 'instead') + kwargs[dst_arg_name] = kwargs.pop(src_arg_name) + + # apply converted arguments to the decorated method + output = old_func(*args, **kwargs) + return output + + return new_func + + return api_warning_wrapper + + +def is_method_overridden(method, base_class, derived_class): + """Check if a method of base class is overridden in derived class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + derived_class (type | Any): the class or instance of the derived class. + """ + assert isinstance(base_class, type), \ + "base_class doesn't accept instance, Please pass class instead." + + if not isinstance(derived_class, type): + derived_class = derived_class.__class__ + + base_method = getattr(base_class, method) + derived_method = getattr(derived_class, method) + return derived_method != base_method + + +def has_method(obj: object, method: str) -> bool: + """Check whether the object has a method. + + Args: + method (str): The method name to check. + obj (object): The object to check. + + Returns: + bool: True if the object has the method else False. + """ + return hasattr(obj, method) and callable(getattr(obj, method)) diff --git a/annotator/uniformer/mmcv/utils/parrots_jit.py b/annotator/uniformer/mmcv/utils/parrots_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e --- /dev/null +++ b/annotator/uniformer/mmcv/utils/parrots_jit.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +from .parrots_wrapper import TORCH_VERSION + +parrots_jit_option = os.getenv('PARROTS_JIT_OPTION') + +if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON': + from parrots.jit import pat as jit +else: + + def jit(func=None, + check_input=None, + full_shape=True, + derivate=False, + coderize=False, + optimize=False): + + def wrapper(func): + + def wrapper_inner(*args, **kargs): + return func(*args, **kargs) + + return wrapper_inner + + if func is None: + return wrapper + else: + return func + + +if TORCH_VERSION == 'parrots': + from parrots.utils.tester import skip_no_elena +else: + + def skip_no_elena(func): + + def wrapper(*args, **kargs): + return func(*args, **kargs) + + return wrapper diff --git a/annotator/uniformer/mmcv/utils/parrots_wrapper.py b/annotator/uniformer/mmcv/utils/parrots_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf --- /dev/null +++ b/annotator/uniformer/mmcv/utils/parrots_wrapper.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial + +import torch + +TORCH_VERSION = torch.__version__ + + +def is_rocm_pytorch() -> bool: + is_rocm = False + if TORCH_VERSION != 'parrots': + try: + from torch.utils.cpp_extension import ROCM_HOME + is_rocm = True if ((torch.version.hip is not None) and + (ROCM_HOME is not None)) else False + except ImportError: + pass + return is_rocm + + +def _get_cuda_home(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import CUDA_HOME + else: + if is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + CUDA_HOME = ROCM_HOME + else: + from torch.utils.cpp_extension import CUDA_HOME + return CUDA_HOME + + +def get_build_config(): + if TORCH_VERSION == 'parrots': + from parrots.config import get_build_info + return get_build_info() + else: + return torch.__config__.show() + + +def _get_conv(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin + else: + from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin + return _ConvNd, _ConvTransposeMixin + + +def _get_dataloader(): + if TORCH_VERSION == 'parrots': + from torch.utils.data import DataLoader, PoolDataLoader + else: + from torch.utils.data import DataLoader + PoolDataLoader = DataLoader + return DataLoader, PoolDataLoader + + +def _get_extension(): + if TORCH_VERSION == 'parrots': + from parrots.utils.build_extension import BuildExtension, Extension + CppExtension = partial(Extension, cuda=False) + CUDAExtension = partial(Extension, cuda=True) + else: + from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + return BuildExtension, CppExtension, CUDAExtension + + +def _get_pool(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + else: + from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, + _AdaptiveMaxPoolNd, _AvgPoolNd, + _MaxPoolNd) + return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd + + +def _get_norm(): + if TORCH_VERSION == 'parrots': + from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm2d + else: + from torch.nn.modules.instancenorm import _InstanceNorm + from torch.nn.modules.batchnorm import _BatchNorm + SyncBatchNorm_ = torch.nn.SyncBatchNorm + return _BatchNorm, _InstanceNorm, SyncBatchNorm_ + + +_ConvNd, _ConvTransposeMixin = _get_conv() +DataLoader, PoolDataLoader = _get_dataloader() +BuildExtension, CppExtension, CUDAExtension = _get_extension() +_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() +_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() + + +class SyncBatchNorm(SyncBatchNorm_): + + def _check_input_dim(self, input): + if TORCH_VERSION == 'parrots': + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input (got {input.dim()}D input)') + else: + super()._check_input_dim(input) diff --git a/annotator/uniformer/mmcv/utils/path.py b/annotator/uniformer/mmcv/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/path.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from pathlib import Path + +from .misc import is_str + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError('`filepath` should be a string or a Path') + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + + Args: + dir_path (str | obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple( + item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, + case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_vcs_root(path, markers=('.git', )): + """Finds the root directory (including itself) of specified markers. + + Args: + path (str): Path of directory or file. + markers (list[str], optional): List of file or directory names. + + Returns: + The directory contained one of the markers or None if not found. + """ + if osp.isfile(path): + path = osp.dirname(path) + + prev, cur = None, osp.abspath(osp.expanduser(path)) + while cur != prev: + if any(osp.exists(osp.join(cur, marker)) for marker in markers): + return cur + prev, cur = cur, osp.split(cur)[0] + return None diff --git a/annotator/uniformer/mmcv/utils/progressbar.py b/annotator/uniformer/mmcv/utils/progressbar.py new file mode 100644 index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/progressbar.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from collections.abc import Iterable +from multiprocessing import Pool +from shutil import get_terminal_size + +from .timer import Timer + + +class ProgressBar: + """A progress bar which can print the progress.""" + + def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): + self.task_num = task_num + self.bar_width = bar_width + self.completed = 0 + self.file = file + if start: + self.start() + + @property + def terminal_width(self): + width, _ = get_terminal_size() + return width + + def start(self): + if self.task_num > 0: + self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' + 'elapsed: 0s, ETA:') + else: + self.file.write('completed: 0, elapsed: 0s') + self.file.flush() + self.timer = Timer() + + def update(self, num_tasks=1): + assert num_tasks > 0 + self.completed += num_tasks + elapsed = self.timer.since_start() + if elapsed > 0: + fps = self.completed / elapsed + else: + fps = float('inf') + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ + f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ + f'ETA: {eta:5}s' + + bar_width = min(self.bar_width, + int(self.terminal_width - len(msg)) + 2, + int(self.terminal_width * 0.6)) + bar_width = max(2, bar_width) + mark_width = int(bar_width * percentage) + bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) + self.file.write(msg.format(bar_chars)) + else: + self.file.write( + f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' + f' {fps:.1f} tasks/s') + self.file.flush() + + +def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): + """Track the progress of tasks execution with a progress bar. + + Tasks are done with a simple for-loop. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + results = [] + for task in tasks: + results.append(func(task, **kwargs)) + prog_bar.update() + prog_bar.file.write('\n') + return results + + +def init_pool(process_num, initializer=None, initargs=None): + if initializer is None: + return Pool(process_num) + elif initargs is None: + return Pool(process_num, initializer) + else: + if not isinstance(initargs, tuple): + raise TypeError('"initargs" must be a tuple') + return Pool(process_num, initializer, initargs) + + +def track_parallel_progress(func, + tasks, + nproc, + initializer=None, + initargs=None, + bar_width=50, + chunksize=1, + skip_first=False, + keep_order=True, + file=sys.stdout): + """Track the progress of parallel task execution with a progress bar. + + The built-in :mod:`multiprocessing` module is used for process pools and + tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. + + Args: + func (callable): The function to be applied to each task. + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + nproc (int): Process (worker) number. + initializer (None or callable): Refer to :class:`multiprocessing.Pool` + for details. + initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for + details. + chunksize (int): Refer to :class:`multiprocessing.Pool` for details. + bar_width (int): Width of progress bar. + skip_first (bool): Whether to skip the first sample for each worker + when estimating fps, since the initialization step may takes + longer. + keep_order (bool): If True, :func:`Pool.imap` is used, otherwise + :func:`Pool.imap_unordered` is used. + + Returns: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + pool = init_pool(nproc, initializer, initargs) + start = not skip_first + task_num -= nproc * chunksize * int(skip_first) + prog_bar = ProgressBar(task_num, bar_width, start, file=file) + results = [] + if keep_order: + gen = pool.imap(func, tasks, chunksize) + else: + gen = pool.imap_unordered(func, tasks, chunksize) + for result in gen: + results.append(result) + if skip_first: + if len(results) < nproc * chunksize: + continue + elif len(results) == nproc * chunksize: + prog_bar.start() + continue + prog_bar.update() + prog_bar.file.write('\n') + pool.close() + pool.join() + return results + + +def track_iter_progress(tasks, bar_width=50, file=sys.stdout): + """Track the progress of tasks iteration or enumeration with a progress + bar. + + Tasks are yielded with a simple for-loop. + + Args: + tasks (list or tuple[Iterable, int]): A list of tasks or + (tasks, total num). + bar_width (int): Width of progress bar. + + Yields: + list: The task results. + """ + if isinstance(tasks, tuple): + assert len(tasks) == 2 + assert isinstance(tasks[0], Iterable) + assert isinstance(tasks[1], int) + task_num = tasks[1] + tasks = tasks[0] + elif isinstance(tasks, Iterable): + task_num = len(tasks) + else: + raise TypeError( + '"tasks" must be an iterable object or a (iterator, int) tuple') + prog_bar = ProgressBar(task_num, bar_width, file=file) + for task in tasks: + yield task + prog_bar.update() + prog_bar.file.write('\n') diff --git a/annotator/uniformer/mmcv/utils/registry.py b/annotator/uniformer/mmcv/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/registry.py @@ -0,0 +1,315 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from config dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an mmcv.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry') + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + \ + f'(name={self._name}, ' \ + f'items={self._module_dict})' + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, \ + f'scope {registry.scope} exists in {self.name} registry' + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError('module must be a class, ' + f'but got {type(module_class)}') + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + 'The old API of register_module(module, force=False) ' + 'is deprecated and will be removed, please use the new API ' + 'register_module(name=None, force=False, module=None) instead.') + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + 'name must be either of None, an instance of str or a sequence' + f' of str, but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module( + module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module( + module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/annotator/uniformer/mmcv/utils/testing.py b/annotator/uniformer/mmcv/utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/testing.py @@ -0,0 +1,140 @@ +# Copyright (c) Open-MMLab. +import sys +from collections.abc import Iterable +from runpy import run_path +from shlex import split +from typing import Any, Dict, List +from unittest.mock import patch + + +def check_python_script(cmd): + """Run the python cmd script with `__main__`. The difference between + `os.system` is that, this function exectues code in the current process, so + that it can be tracked by coverage tools. Currently it supports two forms: + + - ./tests/data/scripts/hello.py zz + - python tests/data/scripts/hello.py zz + """ + args = split(cmd) + if args[0] == 'python': + args = args[1:] + with patch.object(sys, 'argv', args): + run_path(args[0], run_name='__main__') + + +def _any(judge_result): + """Since built-in ``any`` works only when the element of iterable is not + iterable, implement the function.""" + if not isinstance(judge_result, Iterable): + return judge_result + + try: + for element in judge_result: + if _any(element): + return True + except TypeError: + # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) + if judge_result: + return True + return False + + +def assert_dict_contains_subset(dict_obj: Dict[Any, Any], + expected_subset: Dict[Any, Any]) -> bool: + """Check if the dict_obj contains the expected_subset. + + Args: + dict_obj (Dict[Any, Any]): Dict object to be checked. + expected_subset (Dict[Any, Any]): Subset expected to be contained in + dict_obj. + + Returns: + bool: Whether the dict_obj contains the expected_subset. + """ + + for key, value in expected_subset.items(): + if key not in dict_obj.keys() or _any(dict_obj[key] != value): + return False + return True + + +def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: + """Check if attribute of class object is correct. + + Args: + obj (object): Class object to be checked. + expected_attrs (Dict[str, Any]): Dict of the expected attrs. + + Returns: + bool: Whether the attribute of class object is correct. + """ + for attr, value in expected_attrs.items(): + if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): + return False + return True + + +def assert_dict_has_keys(obj: Dict[str, Any], + expected_keys: List[str]) -> bool: + """Check if the obj has all the expected_keys. + + Args: + obj (Dict[str, Any]): Object to be checked. + expected_keys (List[str]): Keys expected to contained in the keys of + the obj. + + Returns: + bool: Whether the obj has the expected keys. + """ + return set(expected_keys).issubset(set(obj.keys())) + + +def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: + """Check if target_keys is equal to result_keys. + + Args: + result_keys (List[str]): Result keys to be checked. + target_keys (List[str]): Target keys to be checked. + + Returns: + bool: Whether target_keys is equal to result_keys. + """ + return set(result_keys) == set(target_keys) + + +def assert_is_norm_layer(module) -> bool: + """Check if the module is a norm layer. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the module is a norm layer. + """ + from .parrots_wrapper import _BatchNorm, _InstanceNorm + from torch.nn import GroupNorm, LayerNorm + norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + return isinstance(module, norm_layer_candidates) + + +def assert_params_all_zeros(module) -> bool: + """Check if the parameters of the module is all zeros. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the parameters of the module is all zeros. + """ + weight_data = module.weight.data + is_weight_zero = weight_data.allclose( + weight_data.new_zeros(weight_data.size())) + + if hasattr(module, 'bias') and module.bias is not None: + bias_data = module.bias.data + is_bias_zero = bias_data.allclose( + bias_data.new_zeros(bias_data.size())) + else: + is_bias_zero = True + + return is_weight_zero and is_bias_zero diff --git a/annotator/uniformer/mmcv/utils/timer.py b/annotator/uniformer/mmcv/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..e3db7d497d8b374e18b5297e0a1d6eb186fd8cba --- /dev/null +++ b/annotator/uniformer/mmcv/utils/timer.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from time import time + + +class TimerError(Exception): + + def __init__(self, message): + self.message = message + super(TimerError, self).__init__(message) + + +class Timer: + """A flexible Timer class. + + :Example: + + >>> import time + >>> import annotator.uniformer.mmcv as mmcv + >>> with mmcv.Timer(): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + 1.000 + >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): + >>> # simulate a code block that will run for 1s + >>> time.sleep(1) + it takes 1.0 seconds + >>> timer = mmcv.Timer() + >>> time.sleep(0.5) + >>> print(timer.since_start()) + 0.500 + >>> time.sleep(0.5) + >>> print(timer.since_last_check()) + 0.500 + >>> print(timer.since_start()) + 1.000 + """ + + def __init__(self, start=True, print_tmpl=None): + self._is_running = False + self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' + if start: + self.start() + + @property + def is_running(self): + """bool: indicate whether the timer is running""" + return self._is_running + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + print(self.print_tmpl.format(self.since_last_check())) + self._is_running = False + + def start(self): + """Start the timer.""" + if not self._is_running: + self._t_start = time() + self._is_running = True + self._t_last = time() + + def since_start(self): + """Total time since the timer is started. + + Returns (float): Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + self._t_last = time() + return self._t_last - self._t_start + + def since_last_check(self): + """Time since the last checking. + + Either :func:`since_start` or :func:`since_last_check` is a checking + operation. + + Returns (float): Time in seconds. + """ + if not self._is_running: + raise TimerError('timer is not running') + dur = time() - self._t_last + self._t_last = time() + return dur + + +_g_timers = {} # global timers + + +def check_time(timer_id): + """Add check points in a single line. + + This method is suitable for running a task on a list of items. A timer will + be registered when the method is called for the first time. + + :Example: + + >>> import time + >>> import annotator.uniformer.mmcv as mmcv + >>> for i in range(1, 6): + >>> # simulate a code block + >>> time.sleep(i) + >>> mmcv.check_time('task1') + 2.000 + 3.000 + 4.000 + 5.000 + + Args: + timer_id (str): Timer identifier. + """ + if timer_id not in _g_timers: + _g_timers[timer_id] = Timer() + return 0 + else: + return _g_timers[timer_id].since_last_check() diff --git a/annotator/uniformer/mmcv/utils/trace.py b/annotator/uniformer/mmcv/utils/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca99dc3eda05ef980d9a4249b50deca8273b6cc --- /dev/null +++ b/annotator/uniformer/mmcv/utils/trace.py @@ -0,0 +1,23 @@ +import warnings + +import torch + +from annotator.uniformer.mmcv.utils import digit_version + + +def is_jit_tracing() -> bool: + if (torch.__version__ != 'parrots' + and digit_version(torch.__version__) >= digit_version('1.6.0')): + on_trace = torch.jit.is_tracing() + # In PyTorch 1.6, torch.jit.is_tracing has a bug. + # Refers to https://github.com/pytorch/pytorch/issues/42448 + if isinstance(on_trace, bool): + return on_trace + else: + return torch._C._is_tracing() + else: + warnings.warn( + 'torch.jit.is_tracing is only supported after v1.6.0. ' + 'Therefore is_tracing returns False automatically. Please ' + 'set on_trace manually if you are using trace.', UserWarning) + return False diff --git a/annotator/uniformer/mmcv/utils/version_utils.py b/annotator/uniformer/mmcv/utils/version_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985 --- /dev/null +++ b/annotator/uniformer/mmcv/utils/version_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import subprocess +import warnings + +from packaging.version import parse + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + assert 'parrots' not in version_str + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + +def get_git_hash(fallback='unknown', digits=None): + """Get the git hash of the current repo. + + Args: + fallback (str, optional): The fallback string when git hash is + unavailable. Defaults to 'unknown'. + digits (int, optional): kept digits of the hash. Defaults to None, + meaning all digits are kept. + + Returns: + str: Git commit hash. + """ + + if digits is not None and not isinstance(digits, int): + raise TypeError('digits must be None or an integer') + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + if digits is not None: + sha = sha[:digits] + except OSError: + sha = fallback + + return sha diff --git a/annotator/uniformer/mmcv/version.py b/annotator/uniformer/mmcv/version.py new file mode 100644 index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0 --- /dev/null +++ b/annotator/uniformer/mmcv/version.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +__version__ = '1.3.17' + + +def parse_version_info(version_str: str, length: int = 4) -> tuple: + """Parse a version string into a tuple. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into + (2, 0, 0, 0, 'rc', 1) (when length is set to 4). + """ + from packaging.version import parse + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + release.extend(list(version.pre)) + elif version.is_postrelease: + release.extend(list(version.post)) + else: + release.extend([0, 0]) + return tuple(release) + + +version_info = tuple(int(x) for x in __version__.split('.')[:3]) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/annotator/uniformer/mmcv/video/__init__.py b/annotator/uniformer/mmcv/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb --- /dev/null +++ b/annotator/uniformer/mmcv/video/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .io import Cache, VideoReader, frames2video +from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread, + flowwrite, quantize_flow, sparse_flow_from_bytes) +from .processing import concat_video, convert_video, cut_video, resize_video + +__all__ = [ + 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video', + 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow', + 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes' +] diff --git a/annotator/uniformer/mmcv/video/io.py b/annotator/uniformer/mmcv/video/io.py new file mode 100644 index 0000000000000000000000000000000000000000..9879154227f640c262853b92c219461c6f67ee8e --- /dev/null +++ b/annotator/uniformer/mmcv/video/io.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict + +import cv2 +from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, + CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, + CAP_PROP_POS_FRAMES, VideoWriter_fourcc) + +from annotator.uniformer.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir, + track_progress) + + +class Cache: + + def __init__(self, capacity): + self._cache = OrderedDict() + self._capacity = int(capacity) + if capacity <= 0: + raise ValueError('capacity must be a positive integer') + + @property + def capacity(self): + return self._capacity + + @property + def size(self): + return len(self._cache) + + def put(self, key, val): + if key in self._cache: + return + if len(self._cache) >= self.capacity: + self._cache.popitem(last=False) + self._cache[key] = val + + def get(self, key, default=None): + val = self._cache[key] if key in self._cache else default + return val + + +class VideoReader: + """Video class with similar usage to a list object. + + This video warpper class provides convenient apis to access frames. + There exists an issue of OpenCV's VideoCapture class that jumping to a + certain frame may be inaccurate. It is fixed in this class by checking + the position after jumping each time. + Cache is used when decoding videos. So if the same frame is visited for + the second time, there is no need to decode again if it is stored in the + cache. + + :Example: + + >>> import annotator.uniformer.mmcv as mmcv + >>> v = mmcv.VideoReader('sample.mp4') + >>> len(v) # get the total frame number with `len()` + 120 + >>> for img in v: # v is iterable + >>> mmcv.imshow(img) + >>> v[5] # get the 6th frame + """ + + def __init__(self, filename, cache_capacity=10): + # Check whether the video path is a url + if not filename.startswith(('https://', 'http://')): + check_file_exist(filename, 'Video file not found: ' + filename) + self._vcap = cv2.VideoCapture(filename) + assert cache_capacity > 0 + self._cache = Cache(cache_capacity) + self._position = 0 + # get basic info + self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH)) + self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT)) + self._fps = self._vcap.get(CAP_PROP_FPS) + self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT)) + self._fourcc = self._vcap.get(CAP_PROP_FOURCC) + + @property + def vcap(self): + """:obj:`cv2.VideoCapture`: The raw VideoCapture object.""" + return self._vcap + + @property + def opened(self): + """bool: Indicate whether the video is opened.""" + return self._vcap.isOpened() + + @property + def width(self): + """int: Width of video frames.""" + return self._width + + @property + def height(self): + """int: Height of video frames.""" + return self._height + + @property + def resolution(self): + """tuple: Video resolution (width, height).""" + return (self._width, self._height) + + @property + def fps(self): + """float: FPS of the video.""" + return self._fps + + @property + def frame_cnt(self): + """int: Total frames of the video.""" + return self._frame_cnt + + @property + def fourcc(self): + """str: "Four character code" of the video.""" + return self._fourcc + + @property + def position(self): + """int: Current cursor position, indicating frame decoded.""" + return self._position + + def _get_real_position(self): + return int(round(self._vcap.get(CAP_PROP_POS_FRAMES))) + + def _set_real_position(self, frame_id): + self._vcap.set(CAP_PROP_POS_FRAMES, frame_id) + pos = self._get_real_position() + for _ in range(frame_id - pos): + self._vcap.read() + self._position = frame_id + + def read(self): + """Read the next frame. + + If the next frame have been decoded before and in the cache, then + return it directly, otherwise decode, cache and return it. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + # pos = self._position + if self._cache: + img = self._cache.get(self._position) + if img is not None: + ret = True + else: + if self._position != self._get_real_position(): + self._set_real_position(self._position) + ret, img = self._vcap.read() + if ret: + self._cache.put(self._position, img) + else: + ret, img = self._vcap.read() + if ret: + self._position += 1 + return img + + def get_frame(self, frame_id): + """Get frame by index. + + Args: + frame_id (int): Index of the expected frame, 0-based. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + if frame_id < 0 or frame_id >= self._frame_cnt: + raise IndexError( + f'"frame_id" must be between 0 and {self._frame_cnt - 1}') + if frame_id == self._position: + return self.read() + if self._cache: + img = self._cache.get(frame_id) + if img is not None: + self._position = frame_id + 1 + return img + self._set_real_position(frame_id) + ret, img = self._vcap.read() + if ret: + if self._cache: + self._cache.put(self._position, img) + self._position += 1 + return img + + def current_frame(self): + """Get the current frame (frame that is just visited). + + Returns: + ndarray or None: If the video is fresh, return None, otherwise + return the frame. + """ + if self._position == 0: + return None + return self._cache.get(self._position - 1) + + def cvt2frames(self, + frame_dir, + file_start=0, + filename_tmpl='{:06d}.jpg', + start=0, + max_num=0, + show_progress=True): + """Convert a video to frame images. + + Args: + frame_dir (str): Output directory to store all the frame images. + file_start (int): Filenames will start from the specified number. + filename_tmpl (str): Filename template with the index as the + placeholder. + start (int): The starting frame index. + max_num (int): Maximum number of frames to be written. + show_progress (bool): Whether to show a progress bar. + """ + mkdir_or_exist(frame_dir) + if max_num == 0: + task_num = self.frame_cnt - start + else: + task_num = min(self.frame_cnt - start, max_num) + if task_num <= 0: + raise ValueError('start must be less than total frame number') + if start > 0: + self._set_real_position(start) + + def write_frame(file_idx): + img = self.read() + if img is None: + return + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + cv2.imwrite(filename, img) + + if show_progress: + track_progress(write_frame, range(file_start, + file_start + task_num)) + else: + for i in range(task_num): + write_frame(file_start + i) + + def __len__(self): + return self.frame_cnt + + def __getitem__(self, index): + if isinstance(index, slice): + return [ + self.get_frame(i) + for i in range(*index.indices(self.frame_cnt)) + ] + # support negative indexing + if index < 0: + index += self.frame_cnt + if index < 0: + raise IndexError('index out of range') + return self.get_frame(index) + + def __iter__(self): + self._set_real_position(0) + return self + + def __next__(self): + img = self.read() + if img is not None: + return img + else: + raise StopIteration + + next = __next__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._vcap.release() + + +def frames2video(frame_dir, + video_file, + fps=30, + fourcc='XVID', + filename_tmpl='{:06d}.jpg', + start=0, + end=0, + show_progress=True): + """Read the frame images from a directory and join them as a video. + + Args: + frame_dir (str): The directory containing video frames. + video_file (str): Output filename. + fps (float): FPS of the output video. + fourcc (str): Fourcc of the output video, this should be compatible + with the output file type. + filename_tmpl (str): Filename template with the index as the variable. + start (int): Starting frame index. + end (int): Ending frame index. + show_progress (bool): Whether to show a progress bar. + """ + if end == 0: + ext = filename_tmpl.split('.')[-1] + end = len([name for name in scandir(frame_dir, ext)]) + first_file = osp.join(frame_dir, filename_tmpl.format(start)) + check_file_exist(first_file, 'The start frame not found: ' + first_file) + img = cv2.imread(first_file) + height, width = img.shape[:2] + resolution = (width, height) + vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps, + resolution) + + def write_frame(file_idx): + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + img = cv2.imread(filename) + vwriter.write(img) + + if show_progress: + track_progress(write_frame, range(start, end)) + else: + for i in range(start, end): + write_frame(i) + vwriter.release() diff --git a/annotator/uniformer/mmcv/video/optflow.py b/annotator/uniformer/mmcv/video/optflow.py new file mode 100644 index 0000000000000000000000000000000000000000..84160f8d6ef9fceb5a2f89e7481593109fc1905d --- /dev/null +++ b/annotator/uniformer/mmcv/video/optflow.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import cv2 +import numpy as np + +from annotator.uniformer.mmcv.arraymisc import dequantize, quantize +from annotator.uniformer.mmcv.image import imread, imwrite +from annotator.uniformer.mmcv.utils import is_str + + +def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_or_path (ndarray or str): A flow map or filepath. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if isinstance(flow_or_path, np.ndarray): + if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2): + raise ValueError(f'Invalid flow with shape {flow_or_path.shape}') + return flow_or_path + elif not is_str(flow_or_path): + raise TypeError(f'"flow_or_path" must be a filename or numpy array, ' + f'not {type(flow_or_path)}') + + if not quantize: + with open(flow_or_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_or_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_or_path}, ' + 'header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + else: + assert concat_axis in [0, 1] + cat_flow = imread(flow_or_path, flag='unchanged') + if cat_flow.ndim != 2: + raise IOError( + f'{flow_or_path} is not a valid quantized flow file, ' + f'its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [ + quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] + ] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'): + """Use flow to warp img. + + Args: + img (ndarray, float or uint8): Image to be warped. + flow (ndarray, float): Optical Flow. + filling_value (int): The missing pixels will be set with filling_value. + interpolate_mode (str): bilinear -> Bilinear Interpolation; + nearest -> Nearest Neighbor. + + Returns: + ndarray: Warped image with the same shape of img + """ + warnings.warn('This function is just for prototyping and cannot ' + 'guarantee the computational efficiency.') + assert flow.ndim == 3, 'Flow must be in 3D arrays.' + height = flow.shape[0] + width = flow.shape[1] + channels = img.shape[2] + + output = np.ones( + (height, width, channels), dtype=img.dtype) * filling_value + + grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2) + dx = grid[:, :, 0] + flow[:, :, 1] + dy = grid[:, :, 1] + flow[:, :, 0] + sx = np.floor(dx).astype(int) + sy = np.floor(dy).astype(int) + valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1) + + if interpolate_mode == 'nearest': + output[valid, :] = img[dx[valid].round().astype(int), + dy[valid].round().astype(int), :] + elif interpolate_mode == 'bilinear': + # dirty walkround for integer positions + eps_ = 1e-6 + dx, dy = dx + eps_, dy + eps_ + left_top_ = img[np.floor(dx[valid]).astype(int), + np.floor(dy[valid]).astype(int), :] * ( + np.ceil(dx[valid]) - dx[valid])[:, None] * ( + np.ceil(dy[valid]) - dy[valid])[:, None] + left_down_ = img[np.ceil(dx[valid]).astype(int), + np.floor(dy[valid]).astype(int), :] * ( + dx[valid] - np.floor(dx[valid]))[:, None] * ( + np.ceil(dy[valid]) - dy[valid])[:, None] + right_top_ = img[np.floor(dx[valid]).astype(int), + np.ceil(dy[valid]).astype(int), :] * ( + np.ceil(dx[valid]) - dx[valid])[:, None] * ( + dy[valid] - np.floor(dy[valid]))[:, None] + right_down_ = img[np.ceil(dx[valid]).astype(int), + np.ceil(dy[valid]).astype(int), :] * ( + dx[valid] - np.floor(dx[valid]))[:, None] * ( + dy[valid] - np.floor(dy[valid]))[:, None] + output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_ + else: + raise NotImplementedError( + 'We only support interpolation modes of nearest and bilinear, ' + f'but got {interpolate_mode}.') + return output.astype(img.dtype) + + +def flow_from_bytes(content): + """Read dense optical flow from bytes. + + .. note:: + This load optical flow function works for FlyingChairs, FlyingThings3D, + Sintel, FlyingChairsOcc datasets, but cannot load the data from + ChairsSDHom. + + Args: + content (bytes): Optical flow bytes got from files or other streams. + + Returns: + ndarray: Loaded optical flow with the shape (H, W, 2). + """ + + # header in first 4 bytes + header = content[:4] + if header.decode('utf-8') != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + # width in second 4 bytes + width = np.frombuffer(content[4:], np.int32, 1).squeeze() + # height in third 4 bytes + height = np.frombuffer(content[8:], np.int32, 1).squeeze() + # after first 12 bytes, all bytes are flow + flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape( + (height, width, 2)) + + return flow + + +def sparse_flow_from_bytes(content): + """Read the optical flow in KITTI datasets from bytes. + + This function is modified from RAFT load the `KITTI datasets + `_. + + Args: + content (bytes): Optical flow bytes got from files or other streams. + + Returns: + Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2) + and flow valid mask with the shape (H, W). + """ # nopa + + content = np.frombuffer(content, np.uint8) + flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + # flow shape (H, W, 2) valid shape (H, W) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid diff --git a/annotator/uniformer/mmcv/video/processing.py b/annotator/uniformer/mmcv/video/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..3d90b96e0823d5f116755e7f498d25d17017224a --- /dev/null +++ b/annotator/uniformer/mmcv/video/processing.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import subprocess +import tempfile + +from annotator.uniformer.mmcv.utils import requires_executable + + +@requires_executable('ffmpeg') +def convert_video(in_file, + out_file, + print_cmd=False, + pre_options='', + **kwargs): + """Convert a video with ffmpeg. + + This provides a general api to ffmpeg, the executed command is:: + + `ffmpeg -y -i ` + + Options(kwargs) are mapped to ffmpeg commands with the following rules: + + - key=val: "-key val" + - key=True: "-key" + - key=False: "" + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + pre_options (str): Options appears before "-i ". + print_cmd (bool): Whether to print the final ffmpeg command. + """ + options = [] + for k, v in kwargs.items(): + if isinstance(v, bool): + if v: + options.append(f'-{k}') + elif k == 'log_level': + assert v in [ + 'quiet', 'panic', 'fatal', 'error', 'warning', 'info', + 'verbose', 'debug', 'trace' + ] + options.append(f'-loglevel {v}') + else: + options.append(f'-{k} {v}') + cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \ + f'{out_file}' + if print_cmd: + print(cmd) + subprocess.call(cmd, shell=True) + + +@requires_executable('ffmpeg') +def resize_video(in_file, + out_file, + size=None, + ratio=None, + keep_ar=False, + log_level='info', + print_cmd=False): + """Resize a video. + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1). + ratio (tuple or float): Expected resize ratio, (2, 0.5) means + (w*2, h*0.5). + keep_ar (bool): Whether to keep original aspect ratio. + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + if size is None and ratio is None: + raise ValueError('expected size or ratio must be specified') + if size is not None and ratio is not None: + raise ValueError('size and ratio cannot be specified at the same time') + options = {'log_level': log_level} + if size: + if not keep_ar: + options['vf'] = f'scale={size[0]}:{size[1]}' + else: + options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \ + 'force_original_aspect_ratio=decrease' + else: + if not isinstance(ratio, tuple): + ratio = (ratio, ratio) + options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"' + convert_video(in_file, out_file, print_cmd, **options) + + +@requires_executable('ffmpeg') +def cut_video(in_file, + out_file, + start=None, + end=None, + vcodec=None, + acodec=None, + log_level='info', + print_cmd=False): + """Cut a clip from a video. + + Args: + in_file (str): Input video filename. + out_file (str): Output video filename. + start (None or float): Start time (in seconds). + end (None or float): End time (in seconds). + vcodec (None or str): Output video codec, None for unchanged. + acodec (None or str): Output audio codec, None for unchanged. + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + options = {'log_level': log_level} + if vcodec is None: + options['vcodec'] = 'copy' + if acodec is None: + options['acodec'] = 'copy' + if start: + options['ss'] = start + else: + start = 0 + if end: + options['t'] = end - start + convert_video(in_file, out_file, print_cmd, **options) + + +@requires_executable('ffmpeg') +def concat_video(video_list, + out_file, + vcodec=None, + acodec=None, + log_level='info', + print_cmd=False): + """Concatenate multiple videos into a single one. + + Args: + video_list (list): A list of video filenames + out_file (str): Output video filename + vcodec (None or str): Output video codec, None for unchanged + acodec (None or str): Output audio codec, None for unchanged + log_level (str): Logging level of ffmpeg. + print_cmd (bool): Whether to print the final ffmpeg command. + """ + tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True) + with open(tmp_filename, 'w') as f: + for filename in video_list: + f.write(f'file {osp.abspath(filename)}\n') + options = {'log_level': log_level} + if vcodec is None: + options['vcodec'] = 'copy' + if acodec is None: + options['acodec'] = 'copy' + convert_video( + tmp_filename, + out_file, + print_cmd, + pre_options='-f concat -safe 0', + **options) + os.close(tmp_filehandler) + os.remove(tmp_filename) diff --git a/annotator/uniformer/mmcv_custom/__init__.py b/annotator/uniformer/mmcv_custom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536 --- /dev/null +++ b/annotator/uniformer/mmcv_custom/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import load_checkpoint + +__all__ = ['load_checkpoint'] \ No newline at end of file diff --git a/annotator/uniformer/mmcv_custom/checkpoint.py b/annotator/uniformer/mmcv_custom/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..19b87fef0a52d31babcdb3edb8f3089b6420173f --- /dev/null +++ b/annotator/uniformer/mmcv_custom/checkpoint.py @@ -0,0 +1,500 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +import annotator.uniformer.mmcv as mmcv +from annotator.uniformer.mmcv.fileio import FileClient +from annotator.uniformer.mmcv.fileio import load as load_file +from annotator.uniformer.mmcv.parallel import is_module_wrapper +from annotator.uniformer.mmcv.utils import mkdir_or_exist +from annotator.uniformer.mmcv.runner import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H*W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() \ No newline at end of file diff --git a/annotator/uniformer/mmseg/apis/__init__.py b/annotator/uniformer/mmseg/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165 --- /dev/null +++ b/annotator/uniformer/mmseg/apis/__init__.py @@ -0,0 +1,9 @@ +from .inference import inference_segmentor, init_segmentor, show_result_pyplot +from .test import multi_gpu_test, single_gpu_test +from .train import get_root_logger, set_random_seed, train_segmentor + +__all__ = [ + 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', + 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', + 'show_result_pyplot' +] diff --git a/annotator/uniformer/mmseg/apis/inference.py b/annotator/uniformer/mmseg/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..90bc1c0c68525734bd6793f07c15fe97d3c8342c --- /dev/null +++ b/annotator/uniformer/mmseg/apis/inference.py @@ -0,0 +1,136 @@ +import matplotlib.pyplot as plt +import annotator.uniformer.mmcv as mmcv +import torch +from annotator.uniformer.mmcv.parallel import collate, scatter +from annotator.uniformer.mmcv.runner import load_checkpoint + +from annotator.uniformer.mmseg.datasets.pipelines import Compose +from annotator.uniformer.mmseg.models import build_segmentor + + +def init_segmentor(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + config.model.pretrained = None + config.model.train_cfg = None + model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + model.CLASSES = checkpoint['meta']['CLASSES'] + model.PALETTE = checkpoint['meta']['PALETTE'] + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +class LoadImage: + """A simple pipeline to load image.""" + + def __call__(self, results): + """Call function to load images into results. + + Args: + results (dict): A result dict contains the file name + of the image to be read. + + Returns: + dict: ``results`` will be returned containing loaded image. + """ + + if isinstance(results['img'], str): + results['filename'] = results['img'] + results['ori_filename'] = results['img'] + else: + results['filename'] = None + results['ori_filename'] = None + img = mmcv.imread(results['img']) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + return results + + +def inference_segmentor(model, img): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + data['img_metas'] = [i.data[0] for i in data['img_metas']] + + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def show_result_pyplot(model, + img, + result, + palette=None, + fig_size=(15, 10), + opacity=0.5, + title='', + block=True): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (list): The segmentation result. + palette (list[list[int]]] | None): The palette of segmentation + map. If None is given, random palette will be generated. + Default: None + fig_size (tuple): Figure size of the pyplot figure. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + block (bool): Whether to block the pyplot figure. + Default is True. + """ + if hasattr(model, 'module'): + model = model.module + img = model.show_result( + img, result, palette=palette, show=False, opacity=opacity) + # plt.figure(figsize=fig_size) + # plt.imshow(mmcv.bgr2rgb(img)) + # plt.title(title) + # plt.tight_layout() + # plt.show(block=block) + return mmcv.bgr2rgb(img) diff --git a/annotator/uniformer/mmseg/apis/test.py b/annotator/uniformer/mmseg/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e574eb7da04f09a59cf99ff953c36468ae87a326 --- /dev/null +++ b/annotator/uniformer/mmseg/apis/test.py @@ -0,0 +1,238 @@ +import os.path as osp +import pickle +import shutil +import tempfile + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch +import torch.distributed as dist +from annotator.uniformer.mmcv.image import tensor2imgs +from annotator.uniformer.mmcv.runner import get_dist_info + + +def np2tmp(array, temp_file_name=None): + """Save ndarray to local numpy file. + + Args: + array (ndarray): Ndarray to save. + temp_file_name (str): Numpy file name. If 'temp_file_name=None', this + function will generate a file name with tempfile.NamedTemporaryFile + to save ndarray. Default: None. + + Returns: + str: The numpy file name. + """ + + if temp_file_name is None: + temp_file_name = tempfile.NamedTemporaryFile( + suffix='.npy', delete=False).name + np.save(temp_file_name, array) + return temp_file_name + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + efficient_test=False, + opacity=0.5): + """Test with single GPU. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + show (bool): Whether show results during inference. Default: False. + out_dir (str, optional): If specified, the results will be dumped into + the directory to save output results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + list: The prediction results. + """ + + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + if show or out_dir: + img_tensor = data['img'][0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for img, img_meta in zip(imgs, img_metas): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = mmcv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result, + palette=dataset.PALETTE, + show=show, + out_file=out_file, + opacity=opacity) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + + Returns: + list: The prediction results. + """ + + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + if rank == 0: + batch_size = data['img'][0].size(0) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results with CPU.""" + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank))) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) + part_list.append(mmcv.load(part_file)) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results with GPU.""" + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_list.append( + pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/annotator/uniformer/mmseg/apis/train.py b/annotator/uniformer/mmseg/apis/train.py new file mode 100644 index 0000000000000000000000000000000000000000..63f319a919ff023931a6a663e668f27dd1a07a2e --- /dev/null +++ b/annotator/uniformer/mmseg/apis/train.py @@ -0,0 +1,116 @@ +import random +import warnings + +import numpy as np +import torch +from annotator.uniformer.mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from annotator.uniformer.mmcv.runner import build_optimizer, build_runner + +from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook +from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset +from annotator.uniformer.mmseg.utils import get_root_logger + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train_segmentor(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + meta=None): + """Launch segmentor training.""" + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + # cfg.gpus will be ignored if distributed + len(cfg.gpu_ids), + dist=distributed, + seed=cfg.seed, + drop_last=True) for ds in dataset + ] + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', False) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + model = MMDataParallel( + model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + + # build runner + optimizer = build_optimizer(model, cfg.optimizer) + + if cfg.get('runner') is None: + cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # register hooks + runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, + cfg.checkpoint_config, cfg.log_config, + cfg.get('momentum_config', None)) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW') + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) diff --git a/annotator/uniformer/mmseg/core/__init__.py b/annotator/uniformer/mmseg/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023 --- /dev/null +++ b/annotator/uniformer/mmseg/core/__init__.py @@ -0,0 +1,3 @@ +from .evaluation import * # noqa: F401, F403 +from .seg import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 diff --git a/annotator/uniformer/mmseg/core/evaluation/__init__.py b/annotator/uniformer/mmseg/core/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/__init__.py @@ -0,0 +1,8 @@ +from .class_names import get_classes, get_palette +from .eval_hooks import DistEvalHook, EvalHook +from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou + +__all__ = [ + 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', + 'eval_metrics', 'get_classes', 'get_palette' +] diff --git a/annotator/uniformer/mmseg/core/evaluation/class_names.py b/annotator/uniformer/mmseg/core/evaluation/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..ffae816cf980ce4b03e491cc0c4298cb823797e6 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/class_names.py @@ -0,0 +1,152 @@ +import annotator.uniformer.mmcv as mmcv + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +dataset_aliases = { + 'cityscapes': ['cityscapes'], + 'ade': ['ade', 'ade20k'], + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels + + +def get_palette(dataset): + """Get class palette (RGB) of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_palette()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc100c8f96e817a6ed2666f7c9f762af2463b48 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py @@ -0,0 +1,109 @@ +import os.path as osp + +from annotator.uniformer.mmcv.runner import DistEvalHook as _DistEvalHook +from annotator.uniformer.mmcv.runner import EvalHook as _EvalHook + + +class EvalHook(_EvalHook): + """Single GPU EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test + + def after_train_iter(self, runner): + """After train epoch hook. + + Override default ``single_gpu_test``. + """ + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test( + runner.model, + self.dataloader, + show=False, + efficient_test=self.efficient_test) + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook. + + Override default ``single_gpu_test``. + """ + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test(runner.model, self.dataloader, show=False) + self.evaluate(runner, results) + + +class DistEvalHook(_DistEvalHook): + """Distributed EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test + + def after_train_iter(self, runner): + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect, + efficient_test=self.efficient_test) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from annotator.uniformer.mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) diff --git a/annotator/uniformer/mmseg/core/evaluation/metrics.py b/annotator/uniformer/mmseg/core/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..16c7dd47cadd53cf1caaa194e28a343f2aacc599 --- /dev/null +++ b/annotator/uniformer/mmseg/core/evaluation/metrics.py @@ -0,0 +1,326 @@ +from collections import OrderedDict + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch + + +def f_score(precision, recall, beta=1): + """calcuate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate intersection and Union. + + Args: + pred_label (ndarray | str): Prediction segmentation map + or predict result filename. + label (ndarray | str): Ground truth segmentation map + or label filename. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. The parameter will + work only when label is str. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. The parameter will + work only when label is str. Default: False. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + if isinstance(pred_label, str): + pred_label = torch.from_numpy(np.load(pred_label)) + else: + pred_label = torch.from_numpy((pred_label)) + + if isinstance(label, str): + label = torch.from_numpy( + mmcv.imread(label, flag='unchanged', backend='pillow')) + else: + label = torch.from_numpy(label) + + if label_map is not None: + for old_id, new_id in label_map.items(): + label[label == old_id] = new_id + if reduce_zero_label: + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + +def total_intersect_and_union(results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Total Intersection and Union. + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + ndarray: The intersection of prediction and ground truth histogram + on all classes. + ndarray: The union of prediction and ground truth histogram on all + classes. + ndarray: The prediction histogram on all classes. + ndarray: The ground truth histogram on all classes. + """ + num_imgs = len(results) + assert len(gt_seg_maps) == num_imgs + total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) + for i in range(num_imgs): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union( + results[i], gt_seg_maps[i], num_classes, ignore_index, + label_map, reduce_zero_label) + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + return total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label + + +def mean_iou(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category IoU, shape (num_classes, ). + """ + iou_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mIoU'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return iou_result + + +def mean_dice(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Dice (mDice) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category dice, shape (num_classes, ). + """ + + dice_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mDice'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return dice_result + + +def mean_fscore(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category recall, shape (num_classes, ). + ndarray: Per category precision, shape (num_classes, ). + ndarray: Per category f-score, shape (num_classes, ). + """ + fscore_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mFscore'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label, + beta=beta) + return fscore_result + + +def eval_metrics(results, + gt_seg_maps, + num_classes, + ignore_index, + metrics=['mIoU'], + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate evaluation metrics + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError('metrics {} is not supported'.format(metrics)) + + total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, + reduce_zero_label) + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor( + [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/annotator/uniformer/mmseg/core/seg/__init__.py b/annotator/uniformer/mmseg/core/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/__init__.py @@ -0,0 +1,4 @@ +from .builder import build_pixel_sampler +from .sampler import BasePixelSampler, OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/annotator/uniformer/mmseg/core/seg/builder.py b/annotator/uniformer/mmseg/core/seg/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..db61f03d4abb2072f2532ce4429c0842495e015b --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/builder.py @@ -0,0 +1,8 @@ +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg + +PIXEL_SAMPLERS = Registry('pixel sampler') + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) diff --git a/annotator/uniformer/mmseg/core/seg/sampler/__init__.py b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1 --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py @@ -0,0 +1,4 @@ +from .base_pixel_sampler import BasePixelSampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['BasePixelSampler', 'OHEMPixelSampler'] diff --git a/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57 --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py @@ -0,0 +1,12 @@ +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f --- /dev/null +++ b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F + +from ..builder import PIXEL_SAMPLERS +from .base_pixel_sampler import BasePixelSampler + + +@PIXEL_SAMPLERS.register_module() +class OHEMPixelSampler(BasePixelSampler): + """Online Hard Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super(OHEMPixelSampler, self).__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + else: + losses = self.context.loss_decode( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa + _, sort_indices = losses[valid_mask].sort(descending=True) + valid_seg_weight[sort_indices[:batch_kept]] = 1. + + seg_weight[valid_mask] = valid_seg_weight + + return seg_weight diff --git a/annotator/uniformer/mmseg/core/utils/__init__.py b/annotator/uniformer/mmseg/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4 --- /dev/null +++ b/annotator/uniformer/mmseg/core/utils/__init__.py @@ -0,0 +1,3 @@ +from .misc import add_prefix + +__all__ = ['add_prefix'] diff --git a/annotator/uniformer/mmseg/core/utils/misc.py b/annotator/uniformer/mmseg/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466 --- /dev/null +++ b/annotator/uniformer/mmseg/core/utils/misc.py @@ -0,0 +1,17 @@ +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/annotator/uniformer/mmseg/datasets/__init__.py b/annotator/uniformer/mmseg/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/__init__.py @@ -0,0 +1,19 @@ +from .ade import ADE20KDataset +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +from .chase_db1 import ChaseDB1Dataset +from .cityscapes import CityscapesDataset +from .custom import CustomDataset +from .dataset_wrappers import ConcatDataset, RepeatDataset +from .drive import DRIVEDataset +from .hrf import HRFDataset +from .pascal_context import PascalContextDataset, PascalContextDataset59 +from .stare import STAREDataset +from .voc import PascalVOCDataset + +__all__ = [ + 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', + 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', + 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', + 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', + 'STAREDataset' +] diff --git a/annotator/uniformer/mmseg/datasets/ade.py b/annotator/uniformer/mmseg/datasets/ade.py new file mode 100644 index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/ade.py @@ -0,0 +1,84 @@ +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ADE20KDataset(CustomDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + CLASSES = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + def __init__(self, **kwargs): + super(ADE20KDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/annotator/uniformer/mmseg/datasets/builder.py b/annotator/uniformer/mmseg/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0798b14cd8b39fc58d8f2a4930f1e079b5bf8b55 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/builder.py @@ -0,0 +1,169 @@ +import copy +import platform +import random +from functools import partial + +import numpy as np +from annotator.uniformer.mmcv.parallel import collate +from annotator.uniformer.mmcv.runner import get_dist_info +from annotator.uniformer.mmcv.utils import Registry, build_from_cfg +from annotator.uniformer.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader +from torch.utils.data import DistributedSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + hard_limit = rlimit[1] + soft_limit = min(4096, hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def _concat_dataset(cfg, default_args=None): + """Build :obj:`ConcatDataset by.""" + from .dataset_wrappers import ConcatDataset + img_dir = cfg['img_dir'] + ann_dir = cfg.get('ann_dir', None) + split = cfg.get('split', None) + num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 + if ann_dir is not None: + num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 + else: + num_ann_dir = 0 + if split is not None: + num_split = len(split) if isinstance(split, (list, tuple)) else 1 + else: + num_split = 0 + if num_img_dir > 1: + assert num_img_dir == num_ann_dir or num_ann_dir == 0 + assert num_img_dir == num_split or num_split == 0 + else: + assert num_split == num_ann_dir or num_ann_dir <= 1 + num_dset = max(num_split, num_img_dir) + + datasets = [] + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + if isinstance(img_dir, (list, tuple)): + data_cfg['img_dir'] = img_dir[i] + if isinstance(ann_dir, (list, tuple)): + data_cfg['ann_dir'] = ann_dir[i] + if isinstance(split, (list, tuple)): + data_cfg['split'] = split[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets) + + +def build_dataset(cfg, default_args=None): + """Build datasets.""" + from .dataset_wrappers import ConcatDataset, RepeatDataset + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( + cfg.get('split', None), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + drop_last=False, + pin_memory=True, + dataloader_type='PoolDataLoader', + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int | None): Seed to be used. Default: None. + drop_last (bool): Whether to drop the last incomplete batch in epoch. + Default: False + pin_memory (bool): Whether to use pin_memory in DataLoader. + Default: True + dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader' + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=shuffle) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + assert dataloader_type in ( + 'DataLoader', + 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' + + if dataloader_type == 'PoolDataLoader': + dataloader = PoolDataLoader + elif dataloader_type == 'DataLoader': + dataloader = DataLoader + + data_loader = dataloader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=pin_memory, + shuffle=shuffle, + worker_init_fn=init_fn, + drop_last=drop_last, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/annotator/uniformer/mmseg/datasets/chase_db1.py b/annotator/uniformer/mmseg/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/chase_db1.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ChaseDB1Dataset(CustomDataset): + """Chase_db1 dataset. + + In segmentation map annotation for Chase_db1, 0 stands for background, + which is included in 2 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_1stHO.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(ChaseDB1Dataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/cityscapes.py b/annotator/uniformer/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..81e47a914a1aa2e5458e18669d65ffb742f46fc6 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/cityscapes.py @@ -0,0 +1,217 @@ +import os.path as osp +import tempfile + +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import print_log +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CityscapesDataset(CustomDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. + """ + + CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle') + + PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]] + + def __init__(self, **kwargs): + super(CityscapesDataset, self).__init__( + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs) + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy + + def results2img(self, results, imgfile_prefix, to_label_id): + """Write the segmentation results to images. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + prog_bar = mmcv.ProgressBar(len(self)) + for idx in range(len(self)): + result = results[idx] + if to_label_id: + result = self._convert_to_label_id(result) + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + output = Image.fromarray(result.astype(np.uint8)).convert('P') + import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) + for label_id, label in CSLabels.id2label.items(): + palette[label_id] = label.color + + output.putpalette(palette) + output.save(png_filename) + result_files.append(png_filename) + prog_bar.update() + + return result_files + + def format_results(self, results, imgfile_prefix=None, to_label_id=True): + """Format the results into dir (standard format for Cityscapes + evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". If not specified, a temp file will be created. + Default: None. + to_label_id (bool): whether convert output to label_id for + submission. Default: False + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: ' + f'{len(results)} != {len(self)}') + + if imgfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + imgfile_prefix = tmp_dir.name + else: + tmp_dir = None + result_files = self.results2img(results, imgfile_prefix, to_label_id) + + return result_files, tmp_dir + + def evaluate(self, + results, + metric='mIoU', + logger=None, + imgfile_prefix=None, + efficient_test=False): + """Evaluation in Cityscapes/default protocol. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file, + for cityscapes evaluation only. It includes the file path and + the prefix of filename, e.g., "a/b/prefix". + If results are evaluated with cityscapes protocol, it would be + the prefix of output png files. The output files would be + png images under folder "a/b/prefix/xxx.png", where "xxx" is + the image name of cityscapes. If not specified, a temp file + will be created for evaluation. + Default: None. + + Returns: + dict[str, float]: Cityscapes/default metrics. + """ + + eval_results = dict() + metrics = metric.copy() if isinstance(metric, list) else [metric] + if 'cityscapes' in metrics: + eval_results.update( + self._evaluate_cityscapes(results, logger, imgfile_prefix)) + metrics.remove('cityscapes') + if len(metrics) > 0: + eval_results.update( + super(CityscapesDataset, + self).evaluate(results, metrics, logger, efficient_test)) + + return eval_results + + def _evaluate_cityscapes(self, results, logger, imgfile_prefix): + """Evaluation in Cityscapes protocol. + + Args: + results (list): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + try: + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + except ImportError: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + result_files, tmp_dir = self.format_results(results, imgfile_prefix) + + if tmp_dir is None: + result_dir = imgfile_prefix + else: + result_dir = tmp_dir.name + + eval_results = dict() + print_log(f'Evaluating results under {result_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + seg_map_list = [] + pred_list = [] + + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + for seg_map in mmcv.scandir( + self.ann_dir, 'gtFine_labelIds.png', recursive=True): + seg_map_list.append(osp.join(self.ann_dir, seg_map)) + pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + + eval_results.update( + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + + if tmp_dir is not None: + tmp_dir.cleanup() + + return eval_results diff --git a/annotator/uniformer/mmseg/datasets/custom.py b/annotator/uniformer/mmseg/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..d8eb2a709cc7a3a68fc6a1e3a1ad98faef4c5b7b --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/custom.py @@ -0,0 +1,400 @@ +import os +import os.path as osp +from collections import OrderedDict +from functools import reduce + +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import print_log +from prettytable import PrettyTable +from torch.utils.data import Dataset + +from annotator.uniformer.mmseg.core import eval_metrics +from annotator.uniformer.mmseg.utils import get_root_logger +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class CustomDataset(Dataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of CustomDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/tutorials/new_dataset.md`` for more details. + + + Args: + pipeline (list[dict]): Processing pipeline + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. Default: '.jpg' + ann_dir (str, optional): Path to annotation directory. Default: None + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + split (str, optional): Split txt file. If split is specified, only + file with suffix in the splits will be loaded. Otherwise, all + images in img_dir/ann_dir will be loaded. Default: None + data_root (str, optional): Data root for img_dir/ann_dir. Default: + None. + test_mode (bool): If test_mode=True, gt wouldn't be loaded. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default: False + classes (str | Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, and + self.PALETTE is None, random palette will be generated. + Default: None + """ + + CLASSES = None + + PALETTE = None + + def __init__(self, + pipeline, + img_dir, + img_suffix='.jpg', + ann_dir=None, + seg_map_suffix='.png', + split=None, + data_root=None, + test_mode=False, + ignore_index=255, + reduce_zero_label=False, + classes=None, + palette=None): + self.pipeline = Compose(pipeline) + self.img_dir = img_dir + self.img_suffix = img_suffix + self.ann_dir = ann_dir + self.seg_map_suffix = seg_map_suffix + self.split = split + self.data_root = data_root + self.test_mode = test_mode + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.label_map = None + self.CLASSES, self.PALETTE = self.get_classes_and_palette( + classes, palette) + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.img_dir): + self.img_dir = osp.join(self.data_root, self.img_dir) + if not (self.ann_dir is None or osp.isabs(self.ann_dir)): + self.ann_dir = osp.join(self.data_root, self.ann_dir) + if not (self.split is None or osp.isabs(self.split)): + self.split = osp.join(self.data_root, self.split) + + # load annotations + self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, + self.ann_dir, + self.seg_map_suffix, self.split) + + def __len__(self): + """Total number of samples of data.""" + return len(self.img_infos) + + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, + split): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + with open(split) as f: + for line in f: + img_name = line.strip() + img_info = dict(filename=img_name + img_suffix) + if ann_dir is not None: + seg_map = img_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in mmcv.scandir(img_dir, img_suffix, recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_map = img.replace(img_suffix, seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos + + def get_ann_info(self, idx): + """Get annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.img_infos[idx]['ann'] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['seg_fields'] = [] + results['img_prefix'] = self.img_dir + results['seg_prefix'] = self.ann_dir + if self.custom_classes: + results['label_map'] = self.label_map + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set + False). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + else: + return self.prepare_train_img(idx) + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + + img_info = self.img_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by + pipeline. + """ + + img_info = self.img_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def format_results(self, results, **kwargs): + """Place holder to format result to dataset specific output.""" + + def get_gt_seg_maps(self, efficient_test=False): + """Get ground truth segmentation maps for evaluation.""" + gt_seg_maps = [] + for img_info in self.img_infos: + seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) + if efficient_test: + gt_seg_map = seg_map + else: + gt_seg_map = mmcv.imread( + seg_map, flag='unchanged', backend='pillow') + gt_seg_maps.append(gt_seg_map) + return gt_seg_maps + + def get_classes_and_palette(self, classes=None, palette=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | str | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is a + string, take it as a file name. The file contains the name of + classes where each line contains one class name. If classes is + a tuple or list, override the CLASSES defined by the dataset. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, random + palette will be generated. Default: None + """ + if classes is None: + self.custom_classes = False + return self.CLASSES, self.PALETTE + + self.custom_classes = True + if isinstance(classes, str): + # take it as a file path + class_names = mmcv.list_from_file(classes) + elif isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if self.CLASSES: + if not set(classes).issubset(self.CLASSES): + raise ValueError('classes is not a subset of CLASSES.') + + # dictionary, its keys are the old label ids and its values + # are the new label ids. + # used for changing pixel labels in load_annotations. + self.label_map = {} + for i, c in enumerate(self.CLASSES): + if c not in class_names: + self.label_map[i] = -1 + else: + self.label_map[i] = classes.index(c) + + palette = self.get_palette_for_custom_classes(class_names, palette) + + return class_names, palette + + def get_palette_for_custom_classes(self, class_names, palette=None): + + if self.label_map is not None: + # return subset of palette + palette = [] + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != -1: + palette.append(self.PALETTE[old_id]) + palette = type(self.PALETTE)(palette) + + elif palette is None: + if self.PALETTE is None: + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + else: + palette = self.PALETTE + + return palette + + def evaluate(self, + results, + metric='mIoU', + logger=None, + efficient_test=False, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. 'mIoU', + 'mDice' and 'mFscore' are supported. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str, float]: Default metrics. + """ + + if isinstance(metric, str): + metric = [metric] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metric).issubset(set(allowed_metrics)): + raise KeyError('metric {} is not supported'.format(metric)) + eval_results = {} + gt_seg_maps = self.get_gt_seg_maps(efficient_test) + if self.CLASSES is None: + num_classes = len( + reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) + else: + num_classes = len(self.CLASSES) + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + self.ignore_index, + metric, + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label) + + if self.CLASSES is None: + class_names = tuple(range(num_classes)) + else: + class_names = self.CLASSES + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + + # for logger + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + summary_table_data = PrettyTable() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + summary_table_data.add_column(key, [val]) + else: + summary_table_data.add_column('m' + key, [val]) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + print_log('Summary:', logger) + print_log('\n' + summary_table_data.get_string(), logger=logger) + + # each metric dict + for key, value in ret_metrics_summary.items(): + if key == 'aAcc': + eval_results[key] = value / 100.0 + else: + eval_results['m' + key] = value / 100.0 + + ret_metrics_class.pop('Class', None) + for key, value in ret_metrics_class.items(): + eval_results.update({ + key + '.' + str(name): value[idx] / 100.0 + for idx, name in enumerate(class_names) + }) + + if mmcv.is_list_of(results, str): + for file_name in results: + os.remove(file_name) + return eval_results diff --git a/annotator/uniformer/mmseg/datasets/dataset_wrappers.py b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py @@ -0,0 +1,50 @@ +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .builder import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + """ + + def __init__(self, datasets): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = datasets[0].PALETTE + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = dataset.PALETTE + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + """Get item from original dataset.""" + return self.dataset[idx % self._ori_len] + + def __len__(self): + """The length is multiplied by ``times``""" + return self.times * self._ori_len diff --git a/annotator/uniformer/mmseg/datasets/drive.py b/annotator/uniformer/mmseg/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/drive.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class DRIVEDataset(CustomDataset): + """DRIVE dataset. + + In segmentation map annotation for DRIVE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(DRIVEDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_manual1.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/hrf.py b/annotator/uniformer/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/hrf.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class HRFDataset(CustomDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(HRFDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/pascal_context.py b/annotator/uniformer/mmseg/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pascal_context.py @@ -0,0 +1,103 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalContextDataset(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', + 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', + 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', + 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', + 'floor', 'flower', 'food', 'grass', 'ground', 'horse', + 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', + 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', + 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', + 'window', 'wood') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) and self.split is not None + + +@DATASETS.register_module() +class PascalContextDataset59(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', + 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', + 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', + 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', + 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', + 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', + 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', + 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', + 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood') + + PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset59, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=True, + **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/annotator/uniformer/mmseg/datasets/pipelines/__init__.py b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py @@ -0,0 +1,16 @@ +from .compose import Compose +from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, + Transpose, to_tensor) +from .loading import LoadAnnotations, LoadImageFromFile +from .test_time_aug import MultiScaleFlipAug +from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomFlip, + RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) + +__all__ = [ + 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', + 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', + 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', + 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' +] diff --git a/annotator/uniformer/mmseg/datasets/pipelines/compose.py b/annotator/uniformer/mmseg/datasets/pipelines/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfcbb925c6d4ebf849328b9f94ef6fc24359bf5 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/compose.py @@ -0,0 +1,51 @@ +import collections + +from annotator.uniformer.mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict | callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, data): + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string diff --git a/annotator/uniformer/mmseg/datasets/pipelines/formating.py b/annotator/uniformer/mmseg/datasets/pipelines/formating.py new file mode 100644 index 0000000000000000000000000000000000000000..97db85f4f9db39fb86ba77ead7d1a8407d810adb --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/formating.py @@ -0,0 +1,288 @@ +from collections.abc import Sequence + +import annotator.uniformer.mmcv as mmcv +import numpy as np +import torch +from annotator.uniformer.mmcv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class ToTensor(object): + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class ImageToTensor(object): + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimension (H, W) is given, the output would be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img.transpose(2, 0, 1)) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class ToDataContainer(object): + """Convert results to :obj:`mmcv.DataContainer` by given fields. + + Args: + fields (Sequence[dict]): Each field is a dict like + ``dict(key='xxx', **kwargs)``. The ``key`` in result will + be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. + Default: ``(dict(key='img', stack=True), + dict(key='gt_semantic_seg'))``. + """ + + def __init__(self, + fields=(dict(key='img', + stack=True), dict(key='gt_semantic_seg'))): + self.fields = fields + + def __call__(self, results): + """Call function to convert data in results to + :obj:`mmcv.DataContainer`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted to + :obj:`mmcv.DataContainer`. + """ + + for field in self.fields: + field = field.copy() + key = field.pop('key') + results[key] = DC(results[key], **field) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(fields={self.fields})' + + +@PIPELINES.register_module() +class DefaultFormatBundle(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img" + and "gt_semantic_seg". These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + if 'gt_semantic_seg' in results: + # convert to long + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, + ...].astype(np.int64)), + stack=True) + return results + + def __repr__(self): + return self.__class__.__name__ + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. Typically keys + is set to some subset of "img", "gt_semantic_seg". + + The "img_meta" item is always populated. The contents of the "img_meta" + dictionary depends on "meta_keys". By default this includes: + + - "img_shape": shape of the image input to the network as a tuple + (h, w, c). Note that images may be zero padded on the bottom/right + if the batch tensor is larger than this shape. + + - "scale_factor": a float indicating the preprocessing scale + + - "flip": a boolean indicating if image flip transform was used + + - "filename": path to the image file + + - "ori_shape": original shape of the image as a tuple (h, w, c) + + - "pad_shape": image shape after padding + + - "img_norm_cfg": a dict of normalization information: + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'img_norm_cfg')`` + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'img_norm_cfg')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + """Call function to collect keys in results. The keys in ``meta_keys`` + will be converted to :obj:mmcv.DataContainer. + + Args: + results (dict): Result dict contains the data to collect. + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``img_metas`` + """ + + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results[key] + data['img_metas'] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' diff --git a/annotator/uniformer/mmseg/datasets/pipelines/loading.py b/annotator/uniformer/mmseg/datasets/pipelines/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..d3692ae91f19b9c7ccf6023168788ff42c9e93e3 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/loading.py @@ -0,0 +1,153 @@ +import os.path as osp + +import annotator.uniformer.mmcv as mmcv +import numpy as np + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class LoadImageFromFile(object): + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'cv2' + """ + + def __init__(self, + to_float32=False, + color_type='color', + file_client_args=dict(backend='disk'), + imdecode_backend='cv2'): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('img_prefix') is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes( + img_bytes, flag=self.color_type, backend=self.imdecode_backend) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32},' + repr_str += f"color_type='{self.color_type}'," + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str + + +@PIPELINES.register_module() +class LoadAnnotations(object): + """Load annotations for semantic segmentation. + + Args: + reduce_zero_label (bool): Whether reduce all label value by 1. + Usually used for datasets where 0 is background label. + Default: False. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'pillow' + """ + + def __init__(self, + reduce_zero_label=False, + file_client_args=dict(backend='disk'), + imdecode_backend='pillow'): + self.reduce_zero_label = reduce_zero_label + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('seg_prefix', None) is not None: + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + else: + filename = results['ann_info']['seg_map'] + img_bytes = self.file_client.get(filename) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + # modify if custom classes + if results.get('label_map', None) is not None: + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg == old_id] = new_id + # reduce zero_label + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + results['gt_semantic_seg'] = gt_semantic_seg + results['seg_fields'].append('gt_semantic_seg') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label},' + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1611a04d9d927223c9afbe5bf68af04d62937a --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py @@ -0,0 +1,133 @@ +import warnings + +import annotator.uniformer.mmcv as mmcv + +from ..builder import PIPELINES +from .compose import Compose + + +@PIPELINES.register_module() +class MultiScaleFlipAug(object): + """Test-time augmentation with multiple scales and flipping. + + An example configuration is as followed: + + .. code-block:: + + img_scale=(2048, 1024), + img_ratios=[0.5, 1.0], + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ] + + After MultiScaleFLipAug with above configuration, the results are wrapped + into lists of the same length as followed: + + .. code-block:: + + dict( + img=[...], + img_shape=[...], + scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] + flip=[False, True, False, True] + ... + ) + + Args: + transforms (list[dict]): Transforms to apply in each augmentation. + img_scale (None | tuple | list[tuple]): Images scales for resizing. + img_ratios (float | list[float]): Image ratios for resizing + flip (bool): Whether apply flip augmentation. Default: False. + flip_direction (str | list[str]): Flip augmentation directions, + options are "horizontal" and "vertical". If flip_direction is list, + multiple flip augmentations will be applied. + It has no effect when flip == False. Default: "horizontal". + """ + + def __init__(self, + transforms, + img_scale, + img_ratios=None, + flip=False, + flip_direction='horizontal'): + self.transforms = Compose(transforms) + if img_ratios is not None: + img_ratios = img_ratios if isinstance(img_ratios, + list) else [img_ratios] + assert mmcv.is_list_of(img_ratios, float) + if img_scale is None: + # mode 1: given img_scale=None and a range of image ratio + self.img_scale = None + assert mmcv.is_list_of(img_ratios, float) + elif isinstance(img_scale, tuple) and mmcv.is_list_of( + img_ratios, float): + assert len(img_scale) == 2 + # mode 2: given a scale and a range of image ratio + self.img_scale = [(int(img_scale[0] * ratio), + int(img_scale[1] * ratio)) + for ratio in img_ratios] + else: + # mode 3: given multiple scales + self.img_scale = img_scale if isinstance(img_scale, + list) else [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None + self.flip = flip + self.img_ratios = img_ratios + self.flip_direction = flip_direction if isinstance( + flip_direction, list) else [flip_direction] + assert mmcv.is_list_of(self.flip_direction, str) + if not self.flip and self.flip_direction != ['horizontal']: + warnings.warn( + 'flip_direction has no effect when flip is set to False') + if (self.flip + and not any([t['type'] == 'RandomFlip' for t in transforms])): + warnings.warn( + 'flip has no effect when RandomFlip is not in transforms') + + def __call__(self, results): + """Call function to apply test time augment transforms on results. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + + aug_data = [] + if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): + h, w = results['img'].shape[:2] + img_scale = [(int(w * ratio), int(h * ratio)) + for ratio in self.img_ratios] + else: + img_scale = self.img_scale + flip_aug = [False, True] if self.flip else [False] + for scale in img_scale: + for flip in flip_aug: + for direction in self.flip_direction: + _results = results.copy() + _results['scale'] = scale + _results['flip'] = flip + _results['flip_direction'] = direction + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'img_scale={self.img_scale}, flip={self.flip})' + repr_str += f'flip_direction={self.flip_direction}' + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/pipelines/transforms.py b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..94e869b252ef6d8b43604add2bbc02f034614bfb --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py @@ -0,0 +1,889 @@ +import annotator.uniformer.mmcv as mmcv +import numpy as np +from annotator.uniformer.mmcv.utils import deprecated_api_warning, is_tuple_of +from numpy import random + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Resize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. + + ``img_scale`` can be None, a tuple (single-scale) or a list of tuple + (multi-scale). There are 4 multiscale modes: + + - ``ratio_range is not None``: + 1. When img_scale is None, img_scale is the shape of image in results + (img_scale = results['img'].shape[:2]) and the image is resized based + on the original size. (mode 1) + 2. When img_scale is a tuple (single-scale), randomly sample a ratio from + the ratio range and multiply it with the image scale. (mode 2) + + - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a + scale from the a range. (mode 3) + + - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a + scale from multiple scales. (mode 4) + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given img_scale=None and a range of image ratio + # mode 2: given a scale and a range of image ratio + assert self.img_scale is None or len(self.img_scale) == 1 + else: + # mode 3 and 4: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results['img'], results['scale'], return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results['img'], results['scale'], return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], results['scale'], interpolation='nearest') + else: + gt_seg = mmcv.imresize( + results[key], results['scale'], interpolation='nearest') + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class RandomFlip(object): + """Flip the image & seg. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + prob (float, optional): The flipping probability. Default: None. + direction(str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. + """ + + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob + self.direction = direction + if prob is not None: + assert prob >= 0 and prob <= 1 + assert direction in ['horizontal', 'vertical'] + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added into + result dict. + """ + + if 'flip' not in results: + flip = True if np.random.rand() < self.prob else False + results['flip'] = flip + if 'flip_direction' not in results: + results['flip_direction'] = self.direction + if results['flip']: + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + # use copy() to make numpy stride positive + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']).copy() + return results + + def __repr__(self): + return self.__class__.__name__ + f'(prob={self.prob})' + + +@PIPELINES.register_module() +class Pad(object): + """Pad the image & mask. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + """ + + def __init__(self, + size=None, + size_divisor=None, + pad_val=0, + seg_pad_val=255): + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + # only one of size and size_divisor should be valid + assert size is not None or size_divisor is not None + assert size is None or size_divisor is None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + if self.size is not None: + padded_img = mmcv.impad( + results['img'], shape=self.size, pad_val=self.pad_val) + elif self.size_divisor is not None: + padded_img = mmcv.impad_to_multiple( + results['img'], self.size_divisor, pad_val=self.pad_val) + results['img'] = padded_img + results['pad_shape'] = padded_img.shape + results['pad_fixed_size'] = self.size + results['pad_size_divisor'] = self.size_divisor + + def _pad_seg(self, results): + """Pad masks according to ``results['pad_shape']``.""" + for key in results.get('seg_fields', []): + results[key] = mmcv.impad( + results[key], + shape=results['pad_shape'][:2], + pad_val=self.seg_pad_val) + + def __call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + + self._pad_img(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ + f'pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + + results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, + self.to_rgb) + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ + f'{self.to_rgb})' + return repr_str + + +@PIPELINES.register_module() +class Rerange(object): + """Rerange the image pixel value. + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def __call__(self, results): + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@PIPELINES.register_module() +class CLAHE(object): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def __call__(self, results): + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop(object): + """Random crop the image & seg. + + Args: + crop_size (tuple): Expected size after cropping, (h, w). + cat_max_ratio (float): The maximum ratio that single category could + occupy. + """ + + def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def __call__(self, results): + """Call function to randomly crop images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.get_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = self.get_crop_bbox(img) + + # crop the image + img = self.crop(img, crop_bbox) + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@PIPELINES.register_module() +class RandomRotate(object): + """Rotate the image & seg. + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + def __call__(self, results): + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate = True if np.random.rand() < self.prob else False + degree = np.random.uniform(min(*self.degree), max(*self.degree)) + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@PIPELINES.register_module() +class RGB2Gray(object): + """Convert RGB image to grayscale image. + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def __call__(self, results): + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@PIPELINES.register_module() +class AdjustGamma(object): + """Using gamma correction to process the image. + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@PIPELINES.register_module() +class SegRescale(object): + """Rescale semantic segmentation maps. + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def __call__(self, results): + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@PIPELINES.register_module() +class PhotoMetricDistortion(object): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, img, alpha=1, beta=0): + """Multiple with alpha and add beat with clip.""" + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img): + """Brightness distortion.""" + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img): + """Contrast distortion.""" + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img): + """Saturation distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img): + """Hue distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def __call__(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str diff --git a/annotator/uniformer/mmseg/datasets/stare.py b/annotator/uniformer/mmseg/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/stare.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class STAREDataset(CustomDataset): + """STARE dataset. + + In segmentation map annotation for STARE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.ah.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(STAREDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.ah.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/annotator/uniformer/mmseg/datasets/voc.py b/annotator/uniformer/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6 --- /dev/null +++ b/annotator/uniformer/mmseg/datasets/voc.py @@ -0,0 +1,29 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalVOCDataset(CustomDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', + 'train', 'tvmonitor') + + PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + def __init__(self, split, **kwargs): + super(PascalVOCDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/annotator/uniformer/mmseg/ops/__init__.py b/annotator/uniformer/mmseg/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c --- /dev/null +++ b/annotator/uniformer/mmseg/ops/__init__.py @@ -0,0 +1,4 @@ +from .encoding import Encoding +from .wrappers import Upsample, resize + +__all__ = ['Upsample', 'resize', 'Encoding'] diff --git a/annotator/uniformer/mmseg/ops/encoding.py b/annotator/uniformer/mmseg/ops/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e --- /dev/null +++ b/annotator/uniformer/mmseg/ops/encoding.py @@ -0,0 +1,74 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super(Encoding, self).__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. / ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/annotator/uniformer/mmseg/ops/wrappers.py b/annotator/uniformer/mmseg/ops/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78 --- /dev/null +++ b/annotator/uniformer/mmseg/ops/wrappers.py @@ -0,0 +1,50 @@ +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/annotator/uniformer/mmseg/utils/__init__.py b/annotator/uniformer/mmseg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/__init__.py @@ -0,0 +1,4 @@ +from .collect_env import collect_env +from .logger import get_root_logger + +__all__ = ['get_root_logger', 'collect_env'] diff --git a/annotator/uniformer/mmseg/utils/collect_env.py b/annotator/uniformer/mmseg/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..65c2134ddbee9655161237dd0894d38c768c2624 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/collect_env.py @@ -0,0 +1,17 @@ +from annotator.uniformer.mmcv.utils import collect_env as collect_base_env +from annotator.uniformer.mmcv.utils import get_git_hash + +import annotator.uniformer.mmseg as mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print('{}: {}'.format(name, val)) diff --git a/annotator/uniformer/mmseg/utils/logger.py b/annotator/uniformer/mmseg/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4149d9eda3dfef07490352d22ac40c42460315e4 --- /dev/null +++ b/annotator/uniformer/mmseg/utils/logger.py @@ -0,0 +1,27 @@ +import logging + +from annotator.uniformer.mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger diff --git a/annotator/util.py b/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..eecc843580f3f10e74df894245ecbfacc090e04b --- /dev/null +++ b/annotator/util.py @@ -0,0 +1,69 @@ +import numpy as np +import cv2 +import os +import torch + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + +def get_control(type): + if type == 'canny': + from .canny import CannyDetector + apply_control = CannyDetector() + elif type == 'openpose': + from .openpose import OpenposeDetector + apply_control = OpenposeDetector() + elif type == 'dwpose': + from .dwpose import DWposeDetector + apply_control = DWposeDetector() + elif type == 'depth' or type == 'normal': + from .midas import MidasDetector + apply_control = MidasDetector() + elif type == 'depth_zoe': + from .zoe import ZoeDetector + apply_control = ZoeDetector() + elif type == 'hed': + from .hed import HEDdetector + apply_control = HEDdetector() + elif type == 'scribble': + apply_control = None + elif type == 'seg': + from .uniformer import UniformerDetector + apply_control = UniformerDetector() + elif type == 'mlsd': + from .mlsd import MLSDdetector + apply_control = MLSDdetector() + else: + raise TypeError(type) + return apply_control + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/annotator/zoe/LICENSE b/annotator/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/annotator/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/annotator/zoe/__init__.py b/annotator/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1a23c072d13c21e4122ad21d458e4a311b62a8 --- /dev/null +++ b/annotator/zoe/__init__.py @@ -0,0 +1,49 @@ +# ZoeDepth +# https://github.com/isl-org/ZoeDepth + +import os +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.utils.config import get_config +from annotator.util import annotator_ckpts_path + + +class ZoeDetector: + def __init__(self): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt" + modelpath = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(modelpath)['model']) + model = model.cuda() + model.device = 'cuda' + model.eval() + self.model = model + + def __call__(self, input_image): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + return depth_image \ No newline at end of file diff --git a/annotator/zoe/__pycache__/__init__.cpython-310.pyc b/annotator/zoe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31ef210fbed36848f56408097842dfaf90c433f1 Binary files /dev/null and b/annotator/zoe/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/zoe/__pycache__/__init__.cpython-38.pyc b/annotator/zoe/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520cecdd8f2efeb090442cdc1ad90922484e650e Binary files /dev/null and b/annotator/zoe/__pycache__/__init__.cpython-38.pyc differ diff --git a/annotator/zoe/__pycache__/__init__.cpython-39.pyc b/annotator/zoe/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31ff5ae4d76104e8586f0425efad03b7d3213ba5 Binary files /dev/null and b/annotator/zoe/__pycache__/__init__.cpython-39.pyc differ diff --git a/annotator/zoe/zoedepth/trainers/base_trainer.py b/annotator/zoe/zoedepth/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6 --- /dev/null +++ b/annotator/zoe/zoedepth/trainers/base_trainer.py @@ -0,0 +1,326 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os +import uuid +import warnings +from datetime import datetime as dt +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import wandb +from tqdm import tqdm + +from zoedepth.utils.config import flatten +from zoedepth.utils.misc import RunningAverageDict, colorize, colors + + +def is_rank_zero(args): + return args.rank == 0 + + +class BaseTrainer: + def __init__(self, config, model, train_loader, test_loader=None, device=None): + """ Base Trainer class for training a model.""" + + self.config = config + self.metric_criterion = "abs_rel" + if device is None: + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = device + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + + def resize_to_target(self, prediction, target): + if prediction.shape[2:] != target.shape[-2:]: + prediction = nn.functional.interpolate( + prediction, size=target.shape[-2:], mode="bilinear", align_corners=True + ) + return prediction + + def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"): + import glob + import os + + from zoedepth.models.model_io import load_wts + + if hasattr(self.config, "checkpoint"): + checkpoint = self.config.checkpoint + elif hasattr(self.config, "ckpt_pattern"): + pattern = self.config.ckpt_pattern + matches = glob.glob(os.path.join( + checkpoint_dir, f"*{pattern}*{ckpt_type}*")) + if not (len(matches) > 0): + raise ValueError(f"No matches found for the pattern {pattern}") + checkpoint = matches[0] + else: + return + model = load_wts(self.model, checkpoint) + # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it. + print("Loaded weights from {0}".format(checkpoint)) + warnings.warn( + "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.") + self.model = model + + def init_optimizer(self): + m = self.model.module if self.config.multigpu else self.model + + if self.config.same_lr: + print("Using same LR") + if hasattr(m, 'core'): + m.core.unfreeze() + params = self.model.parameters() + else: + print("Using diff LR") + if not hasattr(m, 'get_lr_params'): + raise NotImplementedError( + f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.") + + params = m.get_lr_params(self.config.lr) + + return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd) + + def init_scheduler(self): + lrs = [l['lr'] for l in self.optimizer.param_groups] + return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader), + cycle_momentum=self.config.cycle_momentum, + base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase) + + def train_on_batch(self, batch, train_step): + raise NotImplementedError + + def validate_on_batch(self, batch, val_step): + raise NotImplementedError + + def raise_if_nan(self, losses): + for key, value in losses.items(): + if torch.isnan(value): + raise ValueError(f"{key} is NaN, Stopping training") + + @property + def iters_per_epoch(self): + return len(self.train_loader) + + @property + def total_iters(self): + return self.config.epochs * self.iters_per_epoch + + def should_early_stop(self): + if self.config.get('early_stop', False) and self.step > self.config.early_stop: + return True + + def train(self): + print(f"Training {self.config.name}") + if self.config.uid is None: + self.config.uid = str(uuid.uuid4()).split('-')[-1] + run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}" + self.config.run_id = run_id + self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}" + self.should_write = ((not self.config.distributed) + or self.config.rank == 0) + self.should_log = self.should_write # and logging + if self.should_log: + tags = self.config.tags.split( + ',') if self.config.tags != '' else None + wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root, + tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork")) + + self.model.train() + self.step = 0 + best_loss = np.inf + validate_every = int(self.config.validate_every * self.iters_per_epoch) + + + if self.config.prefetch: + + for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader): + pass + + losses = {} + def stringify_losses(L): return "; ".join(map( + lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items())) + for epoch in range(self.config.epochs): + if self.should_early_stop(): + break + + self.epoch = epoch + ################################# Train loop ########################################################## + if self.should_log: + wandb.log({"Epoch": epoch}, step=self.step) + pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader) + for i, batch in pbar: + if self.should_early_stop(): + print("Early stopping") + break + # print(f"Batch {self.step+1} on rank {self.config.rank}") + losses = self.train_on_batch(batch, i) + # print(f"trained batch {self.step+1} on rank {self.config.rank}") + + self.raise_if_nan(losses) + if is_rank_zero(self.config) and self.config.print_losses: + pbar.set_description( + f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {stringify_losses(losses)}") + self.scheduler.step() + + if self.should_log and self.step % 50 == 0: + wandb.log({f"Train/{name}": loss.item() + for name, loss in losses.items()}, step=self.step) + + self.step += 1 + + ######################################################################################################## + + if self.test_loader: + if (self.step % validate_every) == 0: + self.model.eval() + if self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_latest.pt") + + ################################# Validation loop ################################################## + # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log( + {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step) + + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + if self.config.distributed: + dist.barrier() + # print(f"Validated: {metrics} on device {self.config.rank}") + + # print(f"Finished step {self.step} on device {self.config.rank}") + ################################################################################################# + + # Save / validate at the end + self.step += 1 # log as final point + self.model.eval() + self.save_checkpoint(f"{self.config.experiment_id}_latest.pt") + if self.test_loader: + + ################################# Validation loop ################################################## + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log({f"Test/{name}": tloss for name, + tloss in test_losses.items()}, step=self.step) + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + def validate(self): + with torch.no_grad(): + losses_avg = RunningAverageDict() + metrics_avg = RunningAverageDict() + for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)): + metrics, losses = self.validate_on_batch(batch, val_step=i) + + if losses: + losses_avg.update(losses) + if metrics: + metrics_avg.update(metrics) + + return metrics_avg.get_value(), losses_avg.get_value() + + def save_checkpoint(self, filename): + if not self.should_write: + return + root = self.config.save_dir + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + m = self.model.module if self.config.multigpu else self.model + torch.save( + { + "model": m.state_dict(), + "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size + "epoch": self.epoch + }, fpath) + + def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None): + if not self.should_log: + return + + if min_depth is None: + try: + min_depth = self.config.min_depth + max_depth = self.config.max_depth + except AttributeError: + min_depth = None + max_depth = None + + depth = {k: colorize(v, vmin=min_depth, vmax=max_depth) + for k, v in depth.items()} + scalar_field = {k: colorize( + v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()} + images = {**rgb, **depth, **scalar_field} + wimages = { + prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]} + wandb.log(wimages, step=self.step) + + def log_line_plot(self, data): + if not self.should_log: + return + + plt.plot(data) + plt.ylabel("Scale factors") + wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step) + plt.close() + + def log_bar_plot(self, title, labels, values): + if not self.should_log: + return + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=["label", "value"]) + wandb.log({title: wandb.plot.bar(table, "label", + "value", title=title)}, step=self.step) diff --git a/annotator/zoe/zoedepth/trainers/builder.py b/annotator/zoe/zoedepth/trainers/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0 --- /dev/null +++ b/annotator/zoe/zoedepth/trainers/builder.py @@ -0,0 +1,48 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module + + +def get_trainer(config): + """Builds and returns a trainer based on the config. + + Args: + config (dict): the config dict (typically constructed using utils.config.get_config) + config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module + + Raises: + ValueError: If the specified trainer does not exist under trainers/ folder + + Returns: + Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object + """ + assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format( + config) + try: + Trainer = getattr(import_module( + f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer') + except ModuleNotFoundError as e: + raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e + return Trainer diff --git a/annotator/zoe/zoedepth/trainers/loss.py b/annotator/zoe/zoedepth/trainers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab --- /dev/null +++ b/annotator/zoe/zoedepth/trainers/loss.py @@ -0,0 +1,316 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.amp as amp +import numpy as np + + +KEY_OUTPUT = 'metric_depth' + + +def extract_key(prediction, key): + if isinstance(prediction, dict): + return prediction[key] + return prediction + + +# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) +class SILogLoss(nn.Module): + """SILog loss (pixel-wise)""" + def __init__(self, beta=0.15): + super(SILogLoss, self).__init__() + self.name = 'SILog' + self.beta = beta + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + if target.ndim == 3: + target = target.unsqueeze(1) + + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + input = input[mask] + target = target[mask] + + with amp.autocast(enabled=False): # amp causes NaNs in this loss function + alpha = 1e-7 + g = torch.log(input + alpha) - torch.log(target + alpha) + + # n, c, h, w = g.shape + # norm = 1/(h*w) + # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 + + Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) + + loss = 10 * torch.sqrt(Dg) + + if torch.isnan(loss): + print("Nan SILog loss") + print("input:", input.shape) + print("target:", target.shape) + print("G", torch.sum(torch.isnan(g))) + print("Input min max", torch.min(input), torch.max(input)) + print("Target min max", torch.min(target), torch.max(target)) + print("Dg", torch.isnan(Dg)) + print("loss", torch.isnan(loss)) + + if not return_interpolated: + return loss + + return loss, intr_input + + +def grad(x): + # x.shape : n, c, h, w + diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] + diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] + mag = diff_x**2 + diff_y**2 + # angle_ratio + angle = torch.atan(diff_y / (diff_x + 1e-10)) + return mag, angle + + +def grad_mask(mask): + return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:] + + +class GradL1Loss(nn.Module): + """Gradient loss""" + def __init__(self): + super(GradL1Loss, self).__init__() + self.name = 'GradL1' + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + grad_gt = grad(target) + grad_pred = grad(input) + mask_g = grad_mask(mask) + + loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g]) + loss = loss + \ + nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g]) + if not return_interpolated: + return loss + return loss, intr_input + + +class OrdinalRegressionLoss(object): + + def __init__(self, ord_num, beta, discretization="SID"): + self.ord_num = ord_num + self.beta = beta + self.discretization = discretization + + def _create_ord_label(self, gt): + N,one, H, W = gt.shape + # print("gt shape:", gt.shape) + + ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device) + if self.discretization == "SID": + label = self.ord_num * torch.log(gt) / np.log(self.beta) + else: + label = self.ord_num * (gt - 1.0) / (self.beta - 1.0) + label = label.long() + mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \ + .view(1, self.ord_num, 1, 1).to(gt.device) + mask = mask.repeat(N, 1, H, W).contiguous().long() + mask = (mask > label) + ord_c0[mask] = 0 + ord_c1 = 1 - ord_c0 + # implementation according to the paper. + # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device) + # ord_label[:, 0::2, :, :] = ord_c0 + # ord_label[:, 1::2, :, :] = ord_c1 + # reimplementation for fast speed. + ord_label = torch.cat((ord_c0, ord_c1), dim=1) + return ord_label, mask + + def __call__(self, prob, gt): + """ + :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor + :param gt: depth ground truth, NXHxW, torch.Tensor + :return: loss: loss value, torch.float + """ + # N, C, H, W = prob.shape + valid_mask = gt > 0. + ord_label, mask = self._create_ord_label(gt) + # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape)) + entropy = -prob * ord_label + loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)] + return loss.mean() + + +class DiscreteNLLLoss(nn.Module): + """Cross entropy loss""" + def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64): + super(DiscreteNLLLoss, self).__init__() + self.name = 'CrossEntropy' + self.ignore_index = -(depth_bins + 1) + # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index) + self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_bins = depth_bins + self.alpha = 1 + self.zeta = 1 - min_depth + self.beta = max_depth + self.zeta + + def quantize_depth(self, depth): + # depth : N1HW + # output : NCHW + + # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins + depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha) + depth = depth * (self.depth_bins - 1) + depth = torch.round(depth) + depth = depth.long() + return depth + + + + def _dequantize_depth(self, depth): + """ + Inverse of quantization + depth : NCHW -> N1HW + """ + # Get the center of the bin + + + + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + # assert torch.all(input <= 0), "Input should be negative" + + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + # assert torch.all(input)<=1) + if target.ndim == 3: + target = target.unsqueeze(1) + + target = self.quantize_depth(target) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Set the mask to ignore_index + mask = mask.long() + input = input * mask + (1 - mask) * self.ignore_index + target = target * mask + (1 - mask) * self.ignore_index + + + + input = input.flatten(2) # N, nbins, H*W + target = target.flatten(1) # N, H*W + loss = self._loss_func(input, target) + + if not return_interpolated: + return loss + return loss, intr_input + + + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. + valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self): + super().__init__() + self.name = "SSILoss" + + def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False): + + if prediction.shape[-1] != target.shape[-1] and interpolate: + prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = prediction + else: + intr_input = prediction + + + prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze() + assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}." + + scale, shift = compute_scale_and_shift(prediction, target, mask) + + scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask]) + if not return_interpolated: + return loss + return loss, intr_input + + + + +if __name__ == '__main__': + # Tests for DiscreteNLLLoss + celoss = DiscreteNLLLoss() + print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, )) + + d = torch.Tensor([6.59, 3.8, 10.0]) + print(celoss.dequantize_depth(celoss.quantize_depth(d))) diff --git a/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a --- /dev/null +++ b/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics + +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.domain_classifier_loss = nn.CrossEntropyLoss() + + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + + Assumes all images in a batch are from the same dataset + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1 + dataset = batch['dataset'][0] + # Convert to 0s or 1s + domain_labels = torch.Tensor([dataset == 'kitti' for _ in range( + images.size(0))]).to(torch.long).to(self.device) + + # m = self.model.module if self.config.multigpu else self.model + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + output = self.model(images) + pred_depths = output['metric_depth'] + domain_logits = output['domain_logits'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + if self.config.w_domain > 0: + l_domain = self.domain_classifier_loss( + domain_logits, domain_labels) + loss = loss + self.config.w_domain * l_domain + losses["DomainLoss"] = l_domain + else: + l_domain = torch.Tensor([0.]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + return losses + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(images)["metric_depth"] + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + mask = torch.logical_and( + depths_gt > self.config.min_depth, depths_gt < self.config.max_depth) + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1 --- /dev/null +++ b/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py @@ -0,0 +1,177 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics +from zoedepth.data.preprocess import get_black_border + +from .base_trainer import BaseTrainer +from torchvision import transforms +from PIL import Image +import numpy as np + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + dataset = batch['dataset'][0] + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + + output = self.model(images) + pred_depths = output['metric_depth'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + # -99 is treated as invalid depth in the log_images function and is colored grey. + depths_gt[torch.logical_not(mask)] = -99 + + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + if self.config.get("log_rel", False): + self.log_images( + scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel") + + self.scaler.update() + self.optimizer.zero_grad() + + return losses + + @torch.no_grad() + def eval_infer(self, x): + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(x)['metric_depth'] + return pred_depths + + @torch.no_grad() + def crop_aware_infer(self, x): + # if we are not avoiding the black border, we can just use the normal inference + if not self.config.get("avoid_boundary", False): + return self.eval_infer(x) + + # otherwise, we need to crop the image to avoid the black border + # For now, this may be a bit slow due to converting to numpy and back + # We assume no normalization is done on the input image + + # get the black border + assert x.shape[0] == 1, "Only batch size 1 is supported for now" + x_pil = transforms.ToPILImage()(x[0].cpu()) + x_np = np.array(x_pil, dtype=np.uint8) + black_border_params = get_black_border(x_np) + top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right + x_np_cropped = x_np[top:bottom, left:right, :] + x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) + + # run inference on the cropped image + pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) + + # resize the prediction to x_np_cropped's size + pred_depths_cropped = nn.functional.interpolate( + pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) + + + # pad the prediction back to the original size + pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) + pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped + + return pred_depths + + + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + mask = batch["mask"].to(self.device) + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + mask = mask.squeeze().unsqueeze(0).unsqueeze(0) + if dataset == 'nyu': + pred_depths = self.crop_aware_infer(images) + else: + pred_depths = self.eval_infer(images) + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/annotator/zoe/zoedepth/utils/__init__.py b/annotator/zoe/zoedepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/annotator/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e213c1559b6db13234e834170dac3c1cdbf09768 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-38.pyc b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ad46a3bc3e4b21f48f3c3002a9f22bd68c41af8 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-39.pyc b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..975722f74ac35f4a2851669e4d63acd17f2c37e5 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd3dbc6159d6f06090d27b32f9fe5590656c45b6 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-38.pyc b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c079ae063796fc741b64ee4e602eb18852b70421 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-38.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-39.pyc b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61f06693e7d08205ce9135069fa0a2b109df637f Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-39.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c5c28e3dd2b018ff0d0715e9468d8f318d2caed Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-38.pyc b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6d83cadc273986a772dd95259025aeab5cebd8 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-38.pyc differ diff --git a/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-39.pyc b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4d86ed4812513147758fae4d2a7df83449cd950 Binary files /dev/null and b/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-39.pyc differ diff --git a/annotator/zoe/zoedepth/utils/arg_utils.py b/annotator/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/annotator/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/annotator/zoe/zoedepth/utils/config.py b/annotator/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/annotator/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/annotator/zoe/zoedepth/utils/easydict/__init__.py b/annotator/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/annotator/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caeb1d02f6838dc7100623f7667ff6af87427c4f Binary files /dev/null and b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc differ diff --git a/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-38.pyc b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c39d05fd3db949134e5f67d9bc3a6bfbca2ffb3e Binary files /dev/null and b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-38.pyc differ diff --git a/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-39.pyc b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb3bf3cc7c5ec62590838738beeacf6eed148afe Binary files /dev/null and b/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-39.pyc differ diff --git a/annotator/zoe/zoedepth/utils/geometry.py b/annotator/zoe/zoedepth/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c --- /dev/null +++ b/annotator/zoe/zoedepth/utils/geometry.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np + +def get_intrinsics(H,W): + """ + Intrinsics for a pinhole camera model. + Assume fov of 55 degrees and central principal point. + """ + f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0) + cx = 0.5 * W + cy = 0.5 * H + return np.array([[f, 0, cx], + [0, f, cy], + [0, 0, 1]]) + +def depth_to_points(depth, R=None, t=None): + + K = get_intrinsics(depth.shape[1], depth.shape[2]) + Kinv = np.linalg.inv(K) + if R is None: + R = np.eye(3) + if t is None: + t = np.zeros(3) + + # M converts from your coordinate to PyTorch3D's coordinate system + M = np.eye(3) + M[0, 0] = -1.0 + M[1, 1] = -1.0 + + height, width = depth.shape[1:3] + + x = np.arange(width) + y = np.arange(height) + coord = np.stack(np.meshgrid(x, y), -1) + coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1 + coord = coord.astype(np.float32) + # coord = torch.as_tensor(coord, dtype=torch.float32, device=device) + coord = coord[None] # bs, h, w, 3 + + D = depth[:, :, :, None, None] + # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape ) + pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None] + # pts3D_1 live in your coordinate system. Convert them to Py3D's + pts3D_1 = M[None, None, None, ...] @ pts3D_1 + # from reference to targe tviewpoint + pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None] + # pts3D_2 = pts3D_1 + # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w + return pts3D_2[:, :, :, :3, 0][0] + + +def create_triangles(h, w, mask=None): + """ + Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68 + Creates mesh triangle indices from a given pixel grid size. + This function is not and need not be differentiable as triangle indices are + fixed. + Args: + h: (int) denoting the height of the image. + w: (int) denoting the width of the image. + Returns: + triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3) + """ + x, y = np.meshgrid(range(w - 1), range(h - 1)) + tl = y * w + x + tr = y * w + x + 1 + bl = (y + 1) * w + x + br = (y + 1) * w + x + 1 + triangles = np.array([tl, bl, tr, br, tr, bl]) + triangles = np.transpose(triangles, (1, 2, 0)).reshape( + ((w - 1) * (h - 1) * 2, 3)) + if mask is not None: + mask = mask.reshape(-1) + triangles = triangles[mask[triangles].all(1)] + return triangles diff --git a/annotator/zoe/zoedepth/utils/misc.py b/annotator/zoe/zoedepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33 --- /dev/null +++ b/annotator/zoe/zoedepth/utils/misc.py @@ -0,0 +1,368 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +"""Miscellaneous utility functions.""" + +from scipy import ndimage + +import base64 +import math +import re +from io import BytesIO + +import matplotlib +import matplotlib.cm +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.nn +import torch.nn as nn +import torch.utils.data.distributed +from PIL import Image +from torchvision.transforms import ToTensor + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + + +class RunningAverageDict: + """A dictionary of running averages.""" + def __init__(self): + self._dict = None + + def update(self, new_dict): + if new_dict is None: + return + + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + if self._dict is None: + return None + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + # squeeze last dim if it exists + # grey out the invalid values + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + # return img.transpose((2, 0, 1)) + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + +def count_parameters(model, include_all=False): + return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all) + + +def compute_errors(gt, pred): + """Compute metrics for 'pred' compared to 'gt' + + Args: + gt (numpy.ndarray): Ground truth values + pred (numpy.ndarray): Predicted values + + gt.shape should be equal to pred.shape + + Returns: + dict: Dictionary containing the following metrics: + 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25 + 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2 + 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3 + 'abs_rel': Absolute relative error + 'rmse': Root mean squared error + 'log_10': Absolute log10 error + 'sq_rel': Squared relative error + 'rmse_log': Root mean squared error on the log scale + 'silog': Scale invariant log error + """ + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs): + """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics. + """ + if 'config' in kwargs: + config = kwargs['config'] + garg_crop = config.garg_crop + eigen_crop = config.eigen_crop + min_depth_eval = config.min_depth_eval + max_depth_eval = config.max_depth_eval + + if gt.shape[-2:] != pred.shape[-2:] and interpolate: + pred = nn.functional.interpolate( + pred, gt.shape[-2:], mode='bilinear', align_corners=True) + + pred = pred.squeeze().cpu().numpy() + pred[pred < min_depth_eval] = min_depth_eval + pred[pred > max_depth_eval] = max_depth_eval + pred[np.isinf(pred)] = max_depth_eval + pred[np.isnan(pred)] = min_depth_eval + + gt_depth = gt.squeeze().cpu().numpy() + valid_mask = np.logical_and( + gt_depth > min_depth_eval, gt_depth < max_depth_eval) + + if garg_crop or eigen_crop: + gt_height, gt_width = gt_depth.shape + eval_mask = np.zeros(valid_mask.shape) + + if garg_crop: + eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), + int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 + + elif eigen_crop: + # print("-"*10, " EIGEN CROP ", "-"*10) + if dataset == 'kitti': + eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), + int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 + else: + # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images" + eval_mask[45:471, 41:601] = 1 + else: + eval_mask = np.ones(valid_mask.shape) + valid_mask = np.logical_and(valid_mask, eval_mask) + return compute_errors(gt_depth[valid_mask], pred[valid_mask]) + + +#################################### Model uilts ################################################ + + +def parallelize(config, model, find_unused_parameters=True): + + if config.gpu is not None: + torch.cuda.set_device(config.gpu) + model = model.cuda(config.gpu) + + config.multigpu = False + if config.distributed: + # Use DDP + config.multigpu = True + config.rank = config.rank * config.ngpus_per_node + config.gpu + dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, + world_size=config.world_size, rank=config.rank) + config.batch_size = int(config.batch_size / config.ngpus_per_node) + # config.batch_size = 8 + config.workers = int( + (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node) + print("Device", config.gpu, "Rank", config.rank, "batch size", + config.batch_size, "Workers", config.workers) + torch.cuda.set_device(config.gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.cuda(config.gpu) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu, + find_unused_parameters=find_unused_parameters) + + elif config.gpu is None: + # Use DP + config.multigpu = True + model = model.cuda() + model = torch.nn.DataParallel(model) + + return model + + +################################################################################################# + + +##################################################################################################### + + +class colors: + '''Colors class: + Reset all colors with colors.reset + Two subclasses fg for foreground and bg for background. + Use as colors.subclass.colorname. + i.e. colors.fg.red or colors.bg.green + Also, the generic bold, disable, underline, reverse, strikethrough, + and invisible work with the main class + i.e. colors.bold + ''' + reset = '\033[0m' + bold = '\033[01m' + disable = '\033[02m' + underline = '\033[04m' + reverse = '\033[07m' + strikethrough = '\033[09m' + invisible = '\033[08m' + + class fg: + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + orange = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + lightgrey = '\033[37m' + darkgrey = '\033[90m' + lightred = '\033[91m' + lightgreen = '\033[92m' + yellow = '\033[93m' + lightblue = '\033[94m' + pink = '\033[95m' + lightcyan = '\033[96m' + + class bg: + black = '\033[40m' + red = '\033[41m' + green = '\033[42m' + orange = '\033[43m' + blue = '\033[44m' + purple = '\033[45m' + cyan = '\033[46m' + lightgrey = '\033[47m' + + +def printc(text, color): + print(f"{color}{text}{colors.reset}") + +############################################ + +def get_image_from_url(url): + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + return img + +def url_to_torch(url, size=(384, 384)): + img = get_image_from_url(url) + img = img.resize(size, Image.ANTIALIAS) + img = torch.from_numpy(np.asarray(img)).float() + img = img.permute(2, 0, 1) + img.div_(255) + return img + +def pil_to_batched_tensor(img): + return ToTensor()(img).unsqueeze(0) + +def save_raw_16bit(depth, fpath="raw.png"): + if isinstance(depth, torch.Tensor): + depth = depth.squeeze().cpu().numpy() + + assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array" + assert depth.ndim == 2, "Depth must be 2D" + depth = depth * 256 # scale for 16-bit png + depth = depth.astype(np.uint16) + depth = Image.fromarray(depth) + depth.save(fpath) + print("Saved raw depth to", fpath) \ No newline at end of file diff --git a/ckpt/download.sh b/ckpt/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..ee5de97b2370efacd873e972b1823140489db73d --- /dev/null +++ b/ckpt/download.sh @@ -0,0 +1,3 @@ +git lfs install +git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5 +git clone https://huggingface.co/lllyasviel/control_v11p_sd15_openpose \ No newline at end of file diff --git a/config/.gitignore b/config/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f7d13a2d1682680894eefbcc37956fc3f57a65ba --- /dev/null +++ b/config/.gitignore @@ -0,0 +1,2 @@ +# debug/** +trash/* \ No newline at end of file diff --git a/config/run_two_man.yaml b/config/run_two_man.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f34d592c3a2cdf6278414ff46ef62540c809096a --- /dev/null +++ b/config/run_two_man.yaml @@ -0,0 +1,49 @@ +pretrained_model_path: /home/xianyang/Data/code/FateZero/ckpt/stable-diffusion-v1-5 +# "./ckpt/stable-diffusion-v1-5" + +logdir: ./result/run_two_man_test_2nd_time/ + +dataset_config: + path: "data/run_two_man/run_two_man" + prompt: Man in red hoddie and man in gray shirt are jogging in forest + n_sample_frame: 16 + start_sample_frame: 0 + sampling_rate: 1 + layout_mask_dir: "./data/run_two_man/layout_masks" + layout_mask_order: ['left_man','right_man','trees'] + negative_promot: "ugly, blurry, low res, unrealistic, unaesthetic" + +control_config: + control_type: "dwpose" + pretrained_controlnet_path: /home/xianyang/Data/code/FateZero/ckpt/control_v11p_sd15_openpose + #"./ckpt/control_v11p_sd15_openpose" + controlnet_conditioning_scale: 1.0 + hand: True + face: False + +editing_config: + use_invertion_latents: true + inject_step: 0 + old_qk: 1 + flatten_res: [1] + guidance_scale: 7.5 + use_pnp: false + use_freeu: false + editing_prompts: [ + ['Iron Man and Spiderman are jogging under cherry blossoms','Iron Man','Spiderman','cherry blossoms'], + ] + + clip_length: "${..dataset_config.n_sample_frame}" + sample_seeds: [0] + num_inference_steps: 50 + blending_percentage: 0 + + +test_pipeline_config: + target: video_diffusion.pipelines.ddim_spatial_temporal.DDIMSpatioTemporalStableDiffusionPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + + + +seed: 42 + diff --git a/data/run_two_man/layout_masks/left_man/00000.png b/data/run_two_man/layout_masks/left_man/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..5e9428add82a4e9fbbc35f1be2f5ebbe00375e21 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00000.png differ diff --git a/data/run_two_man/layout_masks/left_man/00001.png b/data/run_two_man/layout_masks/left_man/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..650b6592b55b800279166e5d24f6d491100a36e1 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00001.png differ diff --git a/data/run_two_man/layout_masks/left_man/00002.png b/data/run_two_man/layout_masks/left_man/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..9f9085a0ce00a2095d74661846060f2d1a049331 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00002.png differ diff --git a/data/run_two_man/layout_masks/left_man/00003.png b/data/run_two_man/layout_masks/left_man/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..b1f85db348ea7ae297b0817ae65f653310bd70d7 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00003.png differ diff --git a/data/run_two_man/layout_masks/left_man/00004.png b/data/run_two_man/layout_masks/left_man/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..a4f1e38a485098965b5d77cfb82724518abefe65 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00004.png differ diff --git a/data/run_two_man/layout_masks/left_man/00005.png b/data/run_two_man/layout_masks/left_man/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..30791392d338410991784f60828e19403a106715 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00005.png differ diff --git a/data/run_two_man/layout_masks/left_man/00006.png b/data/run_two_man/layout_masks/left_man/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..fb154b7330b36992b5b1d85a697687c1a60bedee Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00006.png differ diff --git a/data/run_two_man/layout_masks/left_man/00007.png b/data/run_two_man/layout_masks/left_man/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..66683d42becef883ecdcf3e6f945d7609bbea90e Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00007.png differ diff --git a/data/run_two_man/layout_masks/left_man/00008.png b/data/run_two_man/layout_masks/left_man/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..d90f61449ce548a8addcdf7a92c344e7f82a42a4 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00008.png differ diff --git a/data/run_two_man/layout_masks/left_man/00009.png b/data/run_two_man/layout_masks/left_man/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..7662dc05fe99d15830d001510b22420148819a4e Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00009.png differ diff --git a/data/run_two_man/layout_masks/left_man/00010.png b/data/run_two_man/layout_masks/left_man/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..e0d5253783b0b91dfaadb6715db78ce7c54a682b Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00010.png differ diff --git a/data/run_two_man/layout_masks/left_man/00011.png b/data/run_two_man/layout_masks/left_man/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..a41126a9398f8729d422d2a87703a55b45d660f1 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00011.png differ diff --git a/data/run_two_man/layout_masks/left_man/00012.png b/data/run_two_man/layout_masks/left_man/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..73908829037fc5172c15558e349177fa5bab0e26 Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00012.png differ diff --git a/data/run_two_man/layout_masks/left_man/00013.png b/data/run_two_man/layout_masks/left_man/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..96f727e9fe95f412e0a10c13c995bf32f5ea8d8b Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00013.png differ diff --git a/data/run_two_man/layout_masks/left_man/00014.png b/data/run_two_man/layout_masks/left_man/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..c894a89d714f926559bcc3d1d6478b89e057e8ea Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00014.png differ diff --git a/data/run_two_man/layout_masks/left_man/00015.png b/data/run_two_man/layout_masks/left_man/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..73c9ecaabcebd9b6714ac288297f444514c7e03c Binary files /dev/null and b/data/run_two_man/layout_masks/left_man/00015.png differ diff --git a/data/run_two_man/layout_masks/right_man/00000.png b/data/run_two_man/layout_masks/right_man/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..1bd1719bfadc363ad4d5318928aa2e0c3c119267 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00000.png differ diff --git a/data/run_two_man/layout_masks/right_man/00001.png b/data/run_two_man/layout_masks/right_man/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..1cb272fec2c7dcc7a4a5cd04cda002fb1d3ea7ae Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00001.png differ diff --git a/data/run_two_man/layout_masks/right_man/00002.png b/data/run_two_man/layout_masks/right_man/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..42cb39f3200d1bf06c76cfcceff0206ee5ab35b7 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00002.png differ diff --git a/data/run_two_man/layout_masks/right_man/00003.png b/data/run_two_man/layout_masks/right_man/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..b71024bd3702c52d67d5807a0695badba5719821 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00003.png differ diff --git a/data/run_two_man/layout_masks/right_man/00004.png b/data/run_two_man/layout_masks/right_man/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..96102b27c00c0ef55ea2e485490700e24277485b Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00004.png differ diff --git a/data/run_two_man/layout_masks/right_man/00005.png b/data/run_two_man/layout_masks/right_man/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..8eb26f584ae1d0c5e6c7c8bb4de50207961d3b60 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00005.png differ diff --git a/data/run_two_man/layout_masks/right_man/00006.png b/data/run_two_man/layout_masks/right_man/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..c375fca69a54c7fcc82fa242f2f334b76a5a68f6 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00006.png differ diff --git a/data/run_two_man/layout_masks/right_man/00007.png b/data/run_two_man/layout_masks/right_man/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..81be191e41ed2b18de7c7a62e14aea1082c27ebd Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00007.png differ diff --git a/data/run_two_man/layout_masks/right_man/00008.png b/data/run_two_man/layout_masks/right_man/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..7866551011767e19624365623c7da0ec5aa987b6 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00008.png differ diff --git a/data/run_two_man/layout_masks/right_man/00009.png b/data/run_two_man/layout_masks/right_man/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..1d7b5cb0f97a11abffb1d2fd884fd6015f03ab47 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00009.png differ diff --git a/data/run_two_man/layout_masks/right_man/00010.png b/data/run_two_man/layout_masks/right_man/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..df1b32764f9e2fbcfd64681792058e99628615cf Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00010.png differ diff --git a/data/run_two_man/layout_masks/right_man/00011.png b/data/run_two_man/layout_masks/right_man/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..550ce0c497a87dcea1e4ea57ad0062adc7ccbe96 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00011.png differ diff --git a/data/run_two_man/layout_masks/right_man/00012.png b/data/run_two_man/layout_masks/right_man/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..83d63f55e97e41bf9186d55c03ad67fd1db4503d Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00012.png differ diff --git a/data/run_two_man/layout_masks/right_man/00013.png b/data/run_two_man/layout_masks/right_man/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..6196ee34eaa5980754cefba7e1bfe06c4f9b7d7c Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00013.png differ diff --git a/data/run_two_man/layout_masks/right_man/00014.png b/data/run_two_man/layout_masks/right_man/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..f7b8dfd6db33f53077249eb5b50492fd5195a622 Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00014.png differ diff --git a/data/run_two_man/layout_masks/right_man/00015.png b/data/run_two_man/layout_masks/right_man/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..f2ef70c161600f8f4ac85705798127534a35a99b Binary files /dev/null and b/data/run_two_man/layout_masks/right_man/00015.png differ diff --git a/data/run_two_man/layout_masks/trees/00000.png b/data/run_two_man/layout_masks/trees/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..080bf1f951969c9c7d643db4060746ba08a210ac Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00000.png differ diff --git a/data/run_two_man/layout_masks/trees/00001.png b/data/run_two_man/layout_masks/trees/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..2fb7ada430ab609acfd725f4eb5ee37b66d23a3e Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00001.png differ diff --git a/data/run_two_man/layout_masks/trees/00002.png b/data/run_two_man/layout_masks/trees/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..bad3d048ae8fb331b8be7843aac32aa79c319178 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00002.png differ diff --git a/data/run_two_man/layout_masks/trees/00003.png b/data/run_two_man/layout_masks/trees/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..dcadbbfed86761f9247370da85ae749857e6cb77 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00003.png differ diff --git a/data/run_two_man/layout_masks/trees/00004.png b/data/run_two_man/layout_masks/trees/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..00d9ef9b38600b395485a2e4246edb15342afd9c Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00004.png differ diff --git a/data/run_two_man/layout_masks/trees/00005.png b/data/run_two_man/layout_masks/trees/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..d4b9b8e6a11546f1f66aca3ddc3e7759c2088942 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00005.png differ diff --git a/data/run_two_man/layout_masks/trees/00006.png b/data/run_two_man/layout_masks/trees/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..703af9491c51ea707310dcc3701b7271150caad5 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00006.png differ diff --git a/data/run_two_man/layout_masks/trees/00007.png b/data/run_two_man/layout_masks/trees/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..3dff36b33deaed614587bf9a8a1012a35a97aac2 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00007.png differ diff --git a/data/run_two_man/layout_masks/trees/00008.png b/data/run_two_man/layout_masks/trees/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..ad6e7bcc4c093a8ceb696e21cfb15e6fa1c4c4ed Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00008.png differ diff --git a/data/run_two_man/layout_masks/trees/00009.png b/data/run_two_man/layout_masks/trees/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..513cfd4e7c53f9024b21b38deabaeb378d5c9c14 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00009.png differ diff --git a/data/run_two_man/layout_masks/trees/00010.png b/data/run_two_man/layout_masks/trees/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..9403d3648126173382d6e775ddea59f6130f8f2e Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00010.png differ diff --git a/data/run_two_man/layout_masks/trees/00011.png b/data/run_two_man/layout_masks/trees/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..c7e7d4c4a87f81d9c25890f712d64059d93db7ed Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00011.png differ diff --git a/data/run_two_man/layout_masks/trees/00012.png b/data/run_two_man/layout_masks/trees/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..acec5b0ea317293137b10d43e36970739dc9c8ec Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00012.png differ diff --git a/data/run_two_man/layout_masks/trees/00013.png b/data/run_two_man/layout_masks/trees/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..1fbf869cd422a55e6719b16a3db286ea8aaf0530 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00013.png differ diff --git a/data/run_two_man/layout_masks/trees/00014.png b/data/run_two_man/layout_masks/trees/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..b6ce73f2ec22cc56dfb27e98796f18d8fd708b17 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00014.png differ diff --git a/data/run_two_man/layout_masks/trees/00015.png b/data/run_two_man/layout_masks/trees/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..138455e698da43101b4f13326903d59353596f41 Binary files /dev/null and b/data/run_two_man/layout_masks/trees/00015.png differ diff --git a/data/run_two_man/run_two_man/00000.png b/data/run_two_man/run_two_man/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..03bccb2480510b91297dedc0e2e91cec3b3291f5 Binary files /dev/null and b/data/run_two_man/run_two_man/00000.png differ diff --git a/data/run_two_man/run_two_man/00001.png b/data/run_two_man/run_two_man/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..c10aa102374abe1c1fe75e2249778e65ab0cbf70 Binary files /dev/null and b/data/run_two_man/run_two_man/00001.png differ diff --git a/data/run_two_man/run_two_man/00002.png b/data/run_two_man/run_two_man/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..7d4c9cba768f325bb2e2ad88e7e9be37ef7f49c6 Binary files /dev/null and b/data/run_two_man/run_two_man/00002.png differ diff --git a/data/run_two_man/run_two_man/00003.png b/data/run_two_man/run_two_man/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..a4495295838a23771f74415d6f325240137d855e Binary files /dev/null and b/data/run_two_man/run_two_man/00003.png differ diff --git a/data/run_two_man/run_two_man/00004.png b/data/run_two_man/run_two_man/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..08ed2100fd3fcee9ecef7fa684f607db3b3c7880 Binary files /dev/null and b/data/run_two_man/run_two_man/00004.png differ diff --git a/data/run_two_man/run_two_man/00005.png b/data/run_two_man/run_two_man/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..2e931439972da4a22ae16f39e7ef0af4486e3156 Binary files /dev/null and b/data/run_two_man/run_two_man/00005.png differ diff --git a/data/run_two_man/run_two_man/00006.png b/data/run_two_man/run_two_man/00006.png new file mode 100644 index 0000000000000000000000000000000000000000..256881d98ba1fc086a3a41cd5f9a8fd2560380fc Binary files /dev/null and b/data/run_two_man/run_two_man/00006.png differ diff --git a/data/run_two_man/run_two_man/00007.png b/data/run_two_man/run_two_man/00007.png new file mode 100644 index 0000000000000000000000000000000000000000..f1f75688a3eadb2f75c8b50a4a2dc22d57c65f8f Binary files /dev/null and b/data/run_two_man/run_two_man/00007.png differ diff --git a/data/run_two_man/run_two_man/00008.png b/data/run_two_man/run_two_man/00008.png new file mode 100644 index 0000000000000000000000000000000000000000..dcf501f505d1ce4131701f3313dc9f91e2d9aabb Binary files /dev/null and b/data/run_two_man/run_two_man/00008.png differ diff --git a/data/run_two_man/run_two_man/00009.png b/data/run_two_man/run_two_man/00009.png new file mode 100644 index 0000000000000000000000000000000000000000..ac35c965e1fd43c957268c151c4ecadaf5b92b3e Binary files /dev/null and b/data/run_two_man/run_two_man/00009.png differ diff --git a/data/run_two_man/run_two_man/00010.png b/data/run_two_man/run_two_man/00010.png new file mode 100644 index 0000000000000000000000000000000000000000..f66d1615ddb412e6332bd15220e8ed4599a29a06 Binary files /dev/null and b/data/run_two_man/run_two_man/00010.png differ diff --git a/data/run_two_man/run_two_man/00011.png b/data/run_two_man/run_two_man/00011.png new file mode 100644 index 0000000000000000000000000000000000000000..7a770e9ac4511416b333d7c41dbe1cc605f909b9 Binary files /dev/null and b/data/run_two_man/run_two_man/00011.png differ diff --git a/data/run_two_man/run_two_man/00012.png b/data/run_two_man/run_two_man/00012.png new file mode 100644 index 0000000000000000000000000000000000000000..c00f1e77bcbcebe7105f4d3c18d5c9cf701bc0c3 Binary files /dev/null and b/data/run_two_man/run_two_man/00012.png differ diff --git a/data/run_two_man/run_two_man/00013.png b/data/run_two_man/run_two_man/00013.png new file mode 100644 index 0000000000000000000000000000000000000000..9242f34a9205e8d60f13c10bc32e87f59ec8f0e0 Binary files /dev/null and b/data/run_two_man/run_two_man/00013.png differ diff --git a/data/run_two_man/run_two_man/00014.png b/data/run_two_man/run_two_man/00014.png new file mode 100644 index 0000000000000000000000000000000000000000..4b788a8d49d4f655077c7e7a7a4be4736aef3168 Binary files /dev/null and b/data/run_two_man/run_two_man/00014.png differ diff --git a/data/run_two_man/run_two_man/00015.png b/data/run_two_man/run_two_man/00015.png new file mode 100644 index 0000000000000000000000000000000000000000..f95811378a9c2c26a94f00dc0c5c5e0701d63a94 Binary files /dev/null and b/data/run_two_man/run_two_man/00015.png differ diff --git a/download_all.sh b/download_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d8d979762351751c4e0510a4d84552a955ba513 --- /dev/null +++ b/download_all.sh @@ -0,0 +1,2 @@ +cd ./ckpt +bash download.sh diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72122ddd3d7c6223b2d3fcf4ef3be368798b744c --- /dev/null +++ b/environment.yaml @@ -0,0 +1,185 @@ +name: st-modulator +channels: + - xformers + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli-python=1.0.9=py310h6a678d5_8 + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2024.7.2=h06a4308_0 + - certifi=2024.7.4=py310h06a4308_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.4.127=h6a678d5_0 + - cuda-runtime=12.1.0=0 + - cuda-version=12.4=hbda6634_3 + - ffmpeg=4.2.2=h20bf706_0 + - filelock=3.13.1=py310h06a4308_0 + - freetype=2.12.1=h4a9f257_0 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py310heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - idna=3.7=py310h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jinja2=3.1.4=py310h06a4308_0 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.9.1.3=h99ab3db_1 + - libcurand=10.3.5.147=h99ab3db_1 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libopus=1.3.1=h7b6447c_0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libvpx=1.7.0=h439df22_0 + - libwebp-base=1.3.2=h5eee18b_0 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=2.1.3=py310h5eee18b_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.3=py310h06a4308_0 + - numpy=1.26.4=py310h5f9d8c6_0 + - numpy-base=1.26.4=py310hb5e798b_0 + - ocl-icd=2.3.2=h5eee18b_1 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h9ca470c_2 + - openssl=3.0.14=h5eee18b_0 + - pillow=10.3.0=py310h5eee18b_0 + - pip=24.0=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.3.1=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py310h5eee18b_0 + - readline=8.2=h5eee18b_0 + - requests=2.32.2=py310h06a4308_0 + - setuptools=69.5.1=py310h06a4308_0 + - sqlite=3.45.3=h5eee18b_0 + - sympy=1.12=py310h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.14=h39e8969_0 + - torchaudio=2.3.1=py310_cu121 + - torchtriton=2.3.1=py310 + - torchvision=0.18.1=py310_cu121 + - typing_extensions=4.11.0=py310h06a4308_0 + - tzdata=2024a=h04d1e81_0 + - urllib3=2.2.2=py310h06a4308_0 + - wheel=0.43.0=py310h06a4308_0 + - x264=1!157.20191217=h7b6447c_0 + - xformers=0.0.27=py310_cu12.1.0_pyt2.3.1 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=h5eee18b_1 + - zstd=1.5.5=hc292b87_2 + - pip: + - absl-py==2.1.0 + - accelerate==0.24.1 + - antlr4-python3-runtime==4.9.3 + - asttokens==2.4.1 + - bitsandbytes==0.35.4 + - click==8.1.7 + - coloredlogs==15.0.1 + - contourpy==1.2.1 + - cycler==0.12.1 + - decorator==5.1.1 + - decord==0.6.0 + - diffusers==0.19.0 + - einops==0.8.0 + - exceptiongroup==1.2.1 + - executing==2.0.1 + - flatbuffers==24.3.25 + - fonttools==4.53.1 + - fsspec==2024.6.1 + - ftfy==6.2.0 + - grpcio==1.65.0 + - huggingface-hub==0.17.3 + - humanfriendly==10.0 + - imageio==2.34.2 + - imageio-ffmpeg==0.5.1 + - importlib-metadata==8.0.0 + - ipython==8.26.0 + - jedi==0.19.1 + - kiwisolver==1.4.5 + - markdown==3.6 + - matplotlib==3.9.1 + - matplotlib-inline==0.1.7 + - modelcards==0.1.6 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.5.82 + - nvidia-nvtx-cu12==12.1.105 + - omegaconf==2.3.0 + - onnxruntime==1.18.1 + - onnxruntime-gpu==1.18.1 + - opencv-python==4.10.0.84 + - packaging==24.1 + - parso==0.8.4 + - pexpect==4.9.0 + - prompt-toolkit==3.0.47 + - protobuf==4.25.3 + - psutil==6.0.0 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyav==12.2.0 + - pygments==2.18.0 + - pyparsing==3.1.2 + - python-dateutil==2.9.0.post0 + - regex==2024.5.15 + - safetensors==0.4.3 + - six==1.16.0 + - stack-data==0.6.3 + - tensorboard==2.17.0 + - tensorboard-data-server==0.7.2 + - tokenizers==0.14.1 + - torch==2.3.1+cu121 + - tqdm==4.66.4 + - traitlets==5.14.3 + - transformers==4.35.0 + - triton==2.3.1 + - wcwidth==0.2.13 + - werkzeug==3.0.3 + - zipp==3.19.2 +prefix: /Data/env/anaconda3/envs/flatten diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..60b78c906de80c2795dcff5a7fa084011a454298 --- /dev/null +++ b/test.py @@ -0,0 +1,413 @@ +import os +from glob import glob +import copy +from typing import Optional,Dict +from tqdm.auto import tqdm +from omegaconf import OmegaConf +import click + +import torch +import torch.utils.data +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DDIMInverseScheduler, +) +from diffusers.utils.import_utils import is_xformers_available +from transformers import AutoTokenizer, CLIPTextModel +from einops import rearrange + +from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel +#from video_diffusion.models.unet import UNet3DConditionModel +from video_diffusion.data.dataset import ImageSequenceDataset +from video_diffusion.common.util import get_time_string, get_function_args +from video_diffusion.common.logger import get_logger_config_path +from video_diffusion.common.image_util import log_train_samples,log_infer_samples,save_tensor_images_and_video,visualize_check_downsample_keypoints,sample_trajectories,save_videos_grid,sample_trajectories_new +from video_diffusion.common.instantiate_from_config import instantiate_from_config +from video_diffusion.pipelines.validation_loop import SampleLogger + +# logger = get_logger(__name__) + +from video_diffusion.models.controlnet3d import ControlNetModel +from annotator.util import get_control, HWC3 +import numpy as np +import imageio +import torchvision +import cv2 +from torchvision import transforms +from PIL import Image +import time + +def collate_fn(examples): + """Concat a batch of sampled image in dataloader + """ + batch = { + "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0), + "images": torch.stack([example["images"] for example in examples]), + "masks": torch.cat([example["masks"] for example in examples]), + "layouts": torch.cat([example["layouts"] for example in examples]), + } + return batch + +def test( + config: str, + pretrained_model_path: str, + dataset_config: Dict, + logdir: str = None, + editing_config: Optional[Dict] = None, + control_config: Optional[Dict] = None, + test_pipeline_config: Optional[Dict] = None, + gradient_accumulation_steps: int = 1, + seed: Optional[int] = None, + mixed_precision: Optional[str] = "fp16", + batch_size: int = 1, + model_config: dict={}, + verbose: bool=True, + **kwargs + +): + args = get_function_args() + + time_string = get_time_string() + if logdir is None: + logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + # logdir += f"_{time_string}" + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + if accelerator.is_main_process: + os.makedirs(logdir, exist_ok=True) + OmegaConf.save(args, os.path.join(logdir, "config.yml")) + logger = get_logger_config_path(logdir) + + if seed is not None: + set_seed(seed) + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_path, + subfolder="tokenizer", + use_fast=False, + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_path, + subfolder="text_encoder", + ) + + vae = AutoencoderKL.from_pretrained( + pretrained_model_path, + subfolder="vae", + ) + + # unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") + unet = UNetPseudo3DConditionModel.from_2d_model( + os.path.join(pretrained_model_path, "unet"), model_config=model_config + ) + + pretrained_controlnet_path = control_config['pretrained_controlnet_path'] + controlnet= ControlNetModel.from_pretrained_2d(pretrained_controlnet_path) + + + if 'target' not in test_pipeline_config: + test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline' + + pipeline = instantiate_from_config( + test_pipeline_config, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=DDIMScheduler.from_pretrained( + pretrained_model_path, + subfolder="scheduler", + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ), + inverse_scheduler=DDIMInverseScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"), + logdir=logdir, + ) + pipeline.scheduler.set_timesteps(editing_config['num_inference_steps']) + # pipeline.set_progress_bar_config(disable=True) + #pipeline.print_pipeline(logger) + + if is_xformers_available(): + try: + pipeline.enable_xformers_memory_efficient_attention() + except Exception as e: + logger.warning( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + controlnet.requires_grad_(False) + + print("org prompt input",dataset_config["prompt"]) + print("edit prompt input",editing_config["editing_prompts"]) + + prompt_ids = tokenizer( + dataset_config["prompt"], + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + video_dataset = ImageSequenceDataset(**dataset_config, prompt_ids=prompt_ids) + + train_dataloader = torch.utils.data.DataLoader( + video_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4, + collate_fn=collate_fn, + ) + train_sample_save_path = os.path.join(logdir, "infer_samples") + log_infer_samples(save_path=train_sample_save_path, infer_dataloader=train_dataloader) + + unet, controlnet, train_dataloader = accelerator.prepare( + unet, controlnet, train_dataloader) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + print('use fp16') + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # These models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("video") # , config=vars(args)) + logger.info("***** wait to fix the logger path *****") + + if editing_config is not None and accelerator.is_main_process: + # validation_sample_logger = P2pSampleLogger(**editing_config, logdir=logdir, source_prompt=dataset_config['prompt']) + validation_sample_logger = SampleLogger(**editing_config, logdir=logdir) + + def make_data_yielder(dataloader): + while True: + for batch in dataloader: + yield batch + accelerator.wait_for_everyone() + + train_data_yielder = make_data_yielder(train_dataloader) + + + batch = next(train_data_yielder) + # if editing_config.get('use_invertion_latents', False): + # Precompute the latents for this video to align the initial latents in training and test + assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video" + # we only inference for latents, no training + + + + + ######precompute control condition########## + + images = batch["images"] # b c f h w, b=1 + b, c, f, height ,width = images.shape + images = (images+1.0)*127.5 # norm back + + ## save source video + + save_videos_grid(batch["images"].cpu(),os.path.join(logdir,"source_video.mp4"),rescale=True) + + images = rearrange(images.to(dtype=torch.float32), "b c f h w -> (b f) h w c") + + control_type = control_config['control_type'] + apply_control = get_control(control_type) + + control = [] + for i in images: + img = i.cpu().numpy() + i = img.astype(np.uint8) + + if control_type == 'canny': + detected_map = apply_control(i, control_config['low_threshold'], control_config['high_threshold']) + elif control_type == 'openpose': + detected_map = apply_control(i, hand=control_config['hand'], face=control_config['face']) + # keypoint.append(candidate_canvas_dict['candidate']) + elif control_type == 'dwpose': + detected_map = apply_control(i, hand=control_config['hand'], face=control_config['face']) + elif control_type == 'depth_zoe': + detected_map = apply_control(i) + elif control_type == 'depth': + detected_map,_ = apply_control(i) + elif control_type == 'hed' or control_type == 'seg': + detected_map = apply_control(i) + elif control_type == 'scribble': + i = i + detected_map = np.zeros_like(i, dtype=np.uint8) + detected_map[np.min(i, axis=2) < control_config.value] = 255 + elif control_type == 'normal': + _, detected_map = apply_control(i, bg_th=control_config['bg_threshold']) + elif control_type == 'mlsd': + detected_map = apply_control(i, control_config['value_threshold'], control_config['distance_threshold']) + else: + raise ValueError(control_type) + control.append(HWC3(detected_map)) + + control = np.stack(control) + control = np.array(control).astype(np.float32) / 255.0 + control = torch.from_numpy(control).to(accelerator.device) + control = control.unsqueeze(0) #[f h w c] -> [b f h w c ] + control = rearrange(control, "b f h w c -> b c f h w") + control = control.to(weight_dtype) + batch['control'] = control + + control_save = control.cpu().float() + + print("save control") + + control_save_dir = os.path.join(logdir, "control") + + save_tensor_images_and_video(control_save, control_save_dir) + + # compute optical flows and sample trajectories + trajectories = sample_trajectories_new(os.path.join(logdir, "source_video.mp4"),accelerator.device,height,width) + + torch.cuda.empty_cache() + + for k in trajectories.keys(): + trajectories[k] = trajectories[k].to(accelerator.device) + + downsample_height, downsample_width = height//8, width//8 + # The externally specified flatten_res + flatten_res = editing_config['flatten_res'] # This could be [1] or [1, 2], etc. + + # Generate the corresponding resolutions + flatten_resolutions = [ + (downsample_height // factor, downsample_width // factor) + for factor in flatten_res + ] + + # Update the editing_config dictionary + editing_config['flatten_res'] = flatten_resolutions + print('flatten res:',editing_config['flatten_res']) + all_start = time.time() + ###ddim inversion scheduler end + + if editing_config['use_freeu']: + from video_diffusion.prompt_attention.free_lunch_utils import apply_freeu + apply_freeu(pipeline, b1=1.2, b2=1.5, s1=1.0, s2=1.0) + if editing_config.get('use_invertion_latents', False): + # Precompute the latents for this video to align the initial latents in training and test + logger.info("use inversion latents") + assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video" + latents, attn_inversion_dict = pipeline.prepare_latents_ddim_inverted( + image=rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"), + batch_size = 1, + source_prompt = dataset_config.prompt, + do_classifier_free_guidance=True, + control=batch['control'], controlnet_conditioning_scale=control_config['controlnet_conditioning_scale'], + use_pnp=editing_config['use_pnp'], + trajs=trajectories, + old_qk=editing_config["old_qk"], + flatten_res=editing_config['flatten_res'] + ) + + batch['ddim_init_latents'] = latents + print("use inversion latents") + + else: + batch['ddim_init_latents'] = None + + + ########### end of code for ddim inversion########### + vae.eval() + text_encoder.eval() + unet.eval() + controlnet.eval() + + # with accelerator.accumulate(unet): + # Convert images to latent space + images = batch["images"].to(dtype=weight_dtype) + images = rearrange(images, "b c f h w -> (b f) c h w") + + masks = batch["masks"].to(dtype=weight_dtype) + b = batch_size + masks = rearrange(masks, f"c f h w -> {b} c f h w") + + layouts = batch["layouts"].to(dtype=weight_dtype) #layouts = f s c h w + + if accelerator.is_main_process: + + if validation_sample_logger is not None: + unet.eval() + validation_sample_logger.log_sample_images( + image=images, # torch.Size([8, 3, 512, 512]) + masks = masks, + layouts = layouts, + pipeline=pipeline, + device=accelerator.device, + step=0, + latents = batch['ddim_init_latents'], + control = batch['control'], + controlnet_conditioning_scale = control_config['controlnet_conditioning_scale'], + blending_percentage = editing_config["blending_percentage"], + trajs=trajectories, + flatten_res = editing_config['flatten_res'], + negative_prompt=[dataset_config['negative_promot']], + source_prompt=dataset_config.prompt, + inject_step=editing_config["inject_step"], + old_qk=editing_config["old_qk"], + use_pnp = editing_config['use_pnp'], + attn_inversion_dict = attn_inversion_dict, + ) + + accelerator.end_training() + +@click.command() +@click.option("--config", type=str, default="config/shape/exp_config/single_object/tennis_3.yaml") +def run(config): + Omegadict = OmegaConf.load(config) + if 'unet' in os.listdir(Omegadict['pretrained_model_path']): + test(config=config, **Omegadict) + else: + # Go through all ckpt if possible + checkpoint_list = sorted(glob(os.path.join(Omegadict['pretrained_model_path'], 'checkpoint_*'))) + print('checkpoint to evaluate:') + for checkpoint in checkpoint_list: + epoch = checkpoint.split('_')[-1] + + for checkpoint in tqdm(checkpoint_list): + epoch = checkpoint.split('_')[-1] + if 'pretrained_epoch_list' not in Omegadict or int(epoch) in Omegadict['pretrained_epoch_list']: + print(f'Evaluate {checkpoint}') + # Update saving dir and ckpt + Omegadict_checkpoint = copy.deepcopy(Omegadict) + Omegadict_checkpoint['pretrained_model_path'] = checkpoint + + if 'logdir' not in Omegadict_checkpoint: + logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + logdir += f"/{os.path.basename(checkpoint)}" + + Omegadict_checkpoint['logdir'] = logdir + print(f'Saving at {logdir}') + + test(config=config, **Omegadict_checkpoint) + + +if __name__ == "__main__": + run() diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..1c40f2ca618bb70cbaafc8e30a7da7421411ea09 --- /dev/null +++ b/test.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0 +accelerate launch test.py --config config/run_two_man.yaml \ No newline at end of file diff --git a/video_diffusion/common/__pycache__/image_util.cpython-310.pyc b/video_diffusion/common/__pycache__/image_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47c260963bce0688eeefa8375df49077c49c960 Binary files /dev/null and b/video_diffusion/common/__pycache__/image_util.cpython-310.pyc differ diff --git a/video_diffusion/common/__pycache__/image_util.cpython-38.pyc b/video_diffusion/common/__pycache__/image_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b475e2a4a738bea02e098ed08dbf1755adc7a1 Binary files /dev/null and b/video_diffusion/common/__pycache__/image_util.cpython-38.pyc differ diff --git a/video_diffusion/common/__pycache__/image_util.cpython-39.pyc b/video_diffusion/common/__pycache__/image_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43a36aa9405ab6804353af839ba1cd380a4aa3fb Binary files /dev/null and b/video_diffusion/common/__pycache__/image_util.cpython-39.pyc differ diff --git a/video_diffusion/common/__pycache__/instantiate_from_config.cpython-310.pyc b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..986925c16a3dc71bd10f98c768d9baf76823a879 Binary files /dev/null and b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-310.pyc differ diff --git a/video_diffusion/common/__pycache__/instantiate_from_config.cpython-38.pyc b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbf4ec543da12f5cd2663c69e9dbe26d4c666ba9 Binary files /dev/null and b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-38.pyc differ diff --git a/video_diffusion/common/__pycache__/instantiate_from_config.cpython-39.pyc b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b309647b97a89f8fdc6ef31689e92c6a480c84d Binary files /dev/null and b/video_diffusion/common/__pycache__/instantiate_from_config.cpython-39.pyc differ diff --git a/video_diffusion/common/__pycache__/logger.cpython-310.pyc b/video_diffusion/common/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a06dca25815e870302256340966d7725345ee9c Binary files /dev/null and b/video_diffusion/common/__pycache__/logger.cpython-310.pyc differ diff --git a/video_diffusion/common/__pycache__/logger.cpython-38.pyc b/video_diffusion/common/__pycache__/logger.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afc5dde1936593431480c68fc97657d9b77ca872 Binary files /dev/null and b/video_diffusion/common/__pycache__/logger.cpython-38.pyc differ diff --git a/video_diffusion/common/__pycache__/logger.cpython-39.pyc b/video_diffusion/common/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dabae565f049be4563cd4d1e3799254a9098b054 Binary files /dev/null and b/video_diffusion/common/__pycache__/logger.cpython-39.pyc differ diff --git a/video_diffusion/common/__pycache__/sd_study_utils.cpython-310.pyc b/video_diffusion/common/__pycache__/sd_study_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f9c8d175ad9ecda1fb564fc2fae93cf9bbead51 Binary files /dev/null and b/video_diffusion/common/__pycache__/sd_study_utils.cpython-310.pyc differ diff --git a/video_diffusion/common/__pycache__/util.cpython-310.pyc b/video_diffusion/common/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80d994d810640c72c7319076cbf529851c11d13 Binary files /dev/null and b/video_diffusion/common/__pycache__/util.cpython-310.pyc differ diff --git a/video_diffusion/common/__pycache__/util.cpython-38.pyc b/video_diffusion/common/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e12abcdd560c07f019452a05b7a4a8ae5bf4242 Binary files /dev/null and b/video_diffusion/common/__pycache__/util.cpython-38.pyc differ diff --git a/video_diffusion/common/__pycache__/util.cpython-39.pyc b/video_diffusion/common/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..911b9255eee9e892afb9948b82ff359f61cf450e Binary files /dev/null and b/video_diffusion/common/__pycache__/util.cpython-39.pyc differ diff --git a/video_diffusion/common/image_util.py b/video_diffusion/common/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..172284337a0c8056d31e397a8a509f6ce906fd12 --- /dev/null +++ b/video_diffusion/common/image_util.py @@ -0,0 +1,730 @@ +import os +import math +import textwrap + +import imageio +import numpy as np +from typing import Sequence +import requests +import cv2 +from PIL import Image, ImageDraw, ImageFont + +import torch +from torchvision import transforms +from einops import rearrange +import torchvision +import imageio + +import torchvision.transforms.functional as F +import random +from scipy.ndimage import binary_dilation +import sys +sys.path.append('/home/xianyang/Data/code/FateZero/video_diffusion/gmflow') +from gmflow.gmflow import GMFlow + +IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".JPEG") + +FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf" +FONT_PATH = "./docs/OpenSans-Regular.ttf" + +np.random.seed(200) +_palette = ((np.random.random((3*255))*0.7+0.3)*255).astype(np.uint8).tolist() +_palette = [0,0,0]+_palette + +def save_prediction(pred_mask,output_dir,file_name): + save_mask = Image.fromarray(pred_mask.astype(np.uint8)) + save_mask = save_mask.convert(mode='P') + save_mask.putpalette(_palette) + save_mask.save(os.path.join(output_dir,file_name)) +def colorize_mask(pred_mask): + save_mask = Image.fromarray(pred_mask.astype(np.uint8)) + save_mask = save_mask.convert(mode='P') + save_mask.putpalette(_palette) + save_mask = save_mask.convert(mode='RGB') + return np.array(save_mask) +def draw_mask(img, mask, alpha=0.5, id_countour=False): + img_mask = np.zeros_like(img) + img_mask = img + if id_countour: + # very slow ~ 1s per image + obj_ids = np.unique(mask) + obj_ids = obj_ids[obj_ids!=0] + + for id in obj_ids: + # Overlay color on binary mask + if id <= 255: + color = _palette[id*3:id*3+3] + else: + color = [0,0,0] + foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color) + binary_mask = (mask == id) + + # Compose image + img_mask[binary_mask] = foreground[binary_mask] + + countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask + img_mask[countours, :] = 0 + else: + binary_mask = (mask!=0) + countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask + foreground = img*(1-alpha)+colorize_mask(mask)*alpha + img_mask[binary_mask] = foreground[binary_mask] + img_mask[countours,:] = 0 + + return img_mask.astype(img.dtype) + + + + +def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image: + new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color) + new_image.paste(image, (left, top)) + return new_image + + +def download_font_opensans(path=FONT_PATH): + font_url = FONT_URL + response = requests.get(font_url) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "wb") as f: + f.write(response.content) + + +def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image: + image_w = image.width + _, _, text_w, text_h = font.getbbox(text) + line_size = math.floor(len(text) * image_w / text_w) + + lines = textwrap.wrap(text, width=line_size) + padding = text_h * len(lines) + image = pad(image, top=padding + 3) + + ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font) + return image + + +def annotate_image(image: Image.Image, text: str, font_size: int = 15): + if not os.path.isfile(FONT_PATH): + download_font_opensans() + font = ImageFont.truetype(FONT_PATH, size=font_size) + return annotate_image_with_font(image=image, text=text, font=font) + + +def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image: + if isinstance(images[0], np.ndarray): + images = [Image.fromarray(i) for i in images] + + if rows is None: + assert cols is not None + rows = math.ceil(len(images) / cols) + else: + cols = math.ceil(len(images) / rows) + + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + for i, image in enumerate(images): + if image.size != (w, h): + image = image.resize((w, h)) + grid.paste(image, box=(i % cols * w, i // cols * h)) + return grid + + +def save_images_as_gif( + images: Sequence[Image.Image], + save_path: str, + loop=0, + duration=100, + optimize=False, +) -> None: + + images[0].save( + save_path, + save_all=True, + append_images=images[1:], + optimize=optimize, + loop=loop, + duration=duration, + ) + +def save_images_as_mp4( + images: Sequence[Image.Image], + save_path: str, +) -> None: + + writer_edit = imageio.get_writer( + save_path, + fps=10) + for i in images: + init_image = i.convert("RGB") + writer_edit.append_data(np.array(init_image)) + writer_edit.close() + + +def save_tensor_images_and_video(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=10): + os.makedirs(path, exist_ok=True) + + # Rearrange video tensor for easier processing + videos = rearrange(videos, "b c t h w -> t b c h w") + + # Lists to store each frame for saving as images and creating a video + frame_list_for_images = [] + + for i, x in enumerate(videos): + # Create a grid of images for this frame + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + + if rescale: + x = (x + 1.0) / 2.0 # Rescale from [-1, 1] to [0, 1] + + x = (x * 255).numpy().astype(np.uint8) + + # Save individual frame as image + save_path_image = os.path.join(path, f"{i}.jpg") + imageio.imsave(save_path_image, x) + + # Append to frame lists + frame_list_for_images.append(x) + + # Save the frames as a video + save_path_video = os.path.join(path, "control.mp4") + imageio.mimwrite(save_path_video, frame_list_for_images, fps=fps) + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + + +def save_images_as_folder( + images: Sequence[Image.Image], + save_path: str, +) -> None: + os.makedirs(save_path, exist_ok=True) + for index, image in enumerate(images): + init_image = image + if len(np.array(init_image).shape) == 3: + cv2.imwrite(os.path.join(save_path, f"{index:05d}.jpg"), np.array(init_image)[:, :, ::-1]) + else: + cv2.imwrite(os.path.join(save_path, f"{index:05d}.jpg"), np.array(init_image)) + +def log_infer_samples( + infer_dataloader, + save_path, + num_batch: int = 4, + fps: int = 8, + save_input=True, +): + infer_samples = [] + infer_masks = [] + infer_merge_masks = [] + for idx, batch in enumerate(infer_dataloader): + if idx >= num_batch: + break + infer_samples.append(batch["images"]) + infer_masks.append(batch["layouts"]) + infer_merge_masks.append(batch["masks"]) + + infer_samples = torch.cat(infer_samples).numpy() + _,_,frames,height,width = infer_samples.shape + infer_samples = rearrange(infer_samples, "b c f h w -> b f h w c") + print('infer_samples',infer_samples.shape) + infer_samples = (infer_samples * 0.5 + 0.5).clip(0, 1) + # infer_samples = numpy_batch_seq_to_pil(infer_samples) + # infer_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(infer_samples))))) for images in zip(*infer_samples)] + infer_merge_masks = torch.cat(infer_merge_masks).unsqueeze(0) + infer_masks = torch.cat(infer_masks) + # f, s, c, h ,w + + infer_masks = rearrange(infer_masks.squeeze(2), "f s h w -> f h w s") + + # 添加一个全为0的mask到第0维 + zero_mask = torch.zeros(infer_masks.shape[0], infer_masks.shape[1], infer_masks.shape[2], 1) + infer_masks = torch.cat((zero_mask, infer_masks), dim=-1) + infer_masks = torch.argmax(infer_masks, axis=-1).numpy() + + masked_frames = [] + for frame_idx in range(frames): + image = np.array(infer_samples[0][frame_idx]) + mask = infer_masks[frame_idx] + mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) + + image = (image * 255).astype(np.uint8) + masked_frame = draw_mask(image, mask, id_countour=False) + masked_frames.append(masked_frame) + #infer_samples_save = rearrange(torch.tensor(infer_samples),'b t h w c -> b c t h w') + if save_input: + infer_samples = numpy_batch_seq_to_pil(infer_samples) + infer_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(infer_samples))))) for images in zip(*infer_samples)] + save_gif_mp4_folder_type(infer_samples, os.path.join(save_path, 'input.gif')) + imageio.mimsave(os.path.join(save_path, 'masked_video.mp4'),masked_frames,fps=fps) + save_videos_grid(infer_merge_masks,os.path.join(save_path, 'merged_masks.mp4'), fps=fps) + + + +def log_train_samples( + train_dataloader, + save_path, + num_batch: int = 4, +): + train_samples = [] + for idx, batch in enumerate(train_dataloader): + if idx >= num_batch: + break + train_samples.append(batch["images"]) + + train_samples = torch.cat(train_samples).numpy() + train_samples = rearrange(train_samples, "b c f h w -> b f h w c") + train_samples = (train_samples * 0.5 + 0.5).clip(0, 1) + train_samples = numpy_batch_seq_to_pil(train_samples) + train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)] + # save_images_as_gif(train_samples, save_path) + save_gif_mp4_folder_type(train_samples, save_path) + +def log_train_reg_samples( + train_dataloader, + save_path, + num_batch: int = 4, +): + train_samples = [] + for idx, batch in enumerate(train_dataloader): + if idx >= num_batch: + break + train_samples.append(batch["class_images"]) + + train_samples = torch.cat(train_samples).numpy() + train_samples = rearrange(train_samples, "b c f h w -> b f h w c") + train_samples = (train_samples * 0.5 + 0.5).clip(0, 1) + train_samples = numpy_batch_seq_to_pil(train_samples) + train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)] + # save_images_as_gif(train_samples, save_path) + save_gif_mp4_folder_type(train_samples, save_path) + + +def save_gif_mp4_folder_type(images, save_path, save_gif=True): + + + if isinstance(images[0], np.ndarray): + images = [Image.fromarray(i) for i in images] + elif isinstance(images[0], torch.Tensor): + images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images] + save_path_mp4 = save_path.replace('gif', 'mp4') + save_path_folder = save_path.replace('.gif', '') + os.makedirs(save_path_folder, exist_ok=True) + if save_gif: save_images_as_gif(images, save_path) + save_images_as_mp4(images, save_path_mp4) + save_images_as_folder(images, save_path_folder) + +# copy from video_diffusion/pipelines/stable_diffusion.py +def numpy_seq_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + +# copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py +def numpy_batch_seq_to_pil(images): + pil_images = [] + for sequence in images: + pil_images.append(numpy_seq_to_pil(sequence)) + return pil_images + + +def downsample_image(image, target_size): + image = Image.fromarray(image) + resized_image = image.resize(target_size, Image.ANTIALIAS) + return np.array(resized_image) + +def visualize_check_downsample_keypoints(images, keypoint_data, target_res=(32, 32),final_res=(512, 512)): + # 预处理帧列表 + processed_frames = [] + + # 遍历每一帧 + for frame_idx, frame_tensor in enumerate(images): + # 将张量转换为 NumPy 数组 + # print("frame",frame_tensor.shape) + frame = frame_tensor.cpu().numpy().astype('uint8') + + # 下采样图片 + downsampled_frame = downsample_image(frame, target_res) + + # 绘制关键点 + for keypoint in keypoint_data[frame_idx]: + h_coord, w_coord = keypoint + downsampled_frame[h_coord, w_coord] = [255, 0, 0] # 使用红色标记关键点 + + # 将处理过的帧重新调整到最终分辨率 + final_frame = downsample_image(downsampled_frame, final_res) + + # 将处理过的帧添加到列表中 + processed_frames.append(final_frame) + + # 使用 imageio 保存处理过的帧为视频 + output_video_path = "./down_sample_check_hockey.mp4" + imageio.mimsave(output_video_path, processed_frames, fps=10) + + +"""optical flow and trajectories sampling""" +def preprocess(img1_batch, img2_batch, transforms, height,width): + img1_batch = F.resize(img1_batch, size=[height, width], antialias=False) + img2_batch = F.resize(img2_batch, size=[height, width], antialias=False) + return transforms(img1_batch, img2_batch) + +def keys_with_same_value(dictionary): + result = {} + for key, value in dictionary.items(): + if value not in result: + result[value] = [key] + else: + result[value].append(key) + + conflict_points = {} + for k in result.keys(): + if len(result[k]) > 1: + conflict_points[k] = result[k] + return conflict_points + +def find_duplicates(input_list): + seen = set() + duplicates = set() + + for item in input_list: + if item in seen: + duplicates.add(item) + else: + seen.add(item) + + return list(duplicates) + +def neighbors_index(point, window_size, H, W): + """return the spatial neighbor indices""" + t, x, y = point + neighbors = [] + for i in range(-window_size, window_size + 1): + for j in range(-window_size, window_size + 1): + if i == 0 and j == 0: + continue + if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W: + continue + neighbors.append((t, x + i, y + j)) + return neighbors + + + +@torch.no_grad() +def sample_trajectories(video_path, device,height,width): + from torchvision.models.optical_flow import Raft_Large_Weights + from torchvision.models.optical_flow import raft_large + + weights = Raft_Large_Weights.DEFAULT + transforms = weights.transforms() + + frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW") + + clips = list(range(len(frames))) + + model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) + model = model.eval() + + finished_trajectories = [] + + current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms, 512,512) + list_of_flows = model(current_frames.to(device), next_frames.to(device)) + predicted_flows = list_of_flows[-1] + print('predicted_flows',predicted_flows.shape) + predicted_flows = predicted_flows/512 + + resolutions = [64, 32, 16, 8] + res = {} + window_sizes = {64: 2, + 32: 1, + 16: 1, + 8: 1} + + for resolution in resolutions: + print("="*30) + trajectories = {} + predicted_flow_resolu = torch.round(resolution*torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution/512, resolution/512))) + + T = predicted_flow_resolu.shape[0]+1 + H = predicted_flow_resolu.shape[2] + W = predicted_flow_resolu.shape[3] + + is_activated = torch.zeros([T, H, W], dtype=torch.bool) + + for t in range(T-1): + flow = predicted_flow_resolu[t] + for h in range(H): + for w in range(W): + + if not is_activated[t, h, w]: + is_activated[t, h, w] = True + # this point has not been traversed, start new trajectory + x = h + int(flow[1, h, w]) + y = w + int(flow[0, h, w]) + if x >= 0 and x < H and y >= 0 and y < W: + # trajectories.append([(t, h, w), (t+1, x, y)]) + trajectories[(t, h, w)]= (t+1, x, y) + + conflict_points = keys_with_same_value(trajectories) + for k in conflict_points: + index_to_pop = random.randint(0, len(conflict_points[k]) - 1) + conflict_points[k].pop(index_to_pop) + for point in conflict_points[k]: + if point[0] != T-1: + trajectories[point]= (-1, -1, -1) # stupid padding with (-1, -1, -1) + + active_traj = [] + all_traj = [] + for t in range(T): + pixel_set = {(t, x//H, x%H):0 for x in range(H*W)} + new_active_traj = [] + for traj in active_traj: + if traj[-1] in trajectories: + v = trajectories[traj[-1]] + new_active_traj.append(traj + [v]) + pixel_set[v] = 1 + else: + all_traj.append(traj) + active_traj = new_active_traj + active_traj+=[[pixel] for pixel in pixel_set if pixel_set[pixel] == 0] + all_traj += active_traj + + useful_traj = [i for i in all_traj if len(i)>1] + for idx in range(len(useful_traj)): + if useful_traj[idx][-1] == (-1, -1, -1): + useful_traj[idx] = useful_traj[idx][:-1] + print("how many points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj])) + print("how many points in the video for resolution{}?".format(resolution), T*H*W) + + # validate if there are no duplicates in the trajectories + trajs = [] + for traj in useful_traj: + trajs = trajs + traj + assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories." + + # check if non-appearing points + appearing points = all the points in the video + all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)]) + left_points = all_points- set(trajs) + print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points)) + for p in list(left_points): + useful_traj.append([p]) + print("how many points in all trajectories for resolution{} after pending?".format(resolution), sum([len(i) for i in useful_traj])) + + + longest_length = max([len(i) for i in useful_traj]) + sequence_length = (window_sizes[resolution]*2+1)**2 + longest_length - 1 + + seqs = [] + masks = [] + + # create a dictionary to facilitate checking the trajectories to which each point belongs. + point_to_traj = {} + for traj in useful_traj: + for p in traj: + point_to_traj[p] = traj + + for t in range(T): + for x in range(H): + for y in range(W): + neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W) + sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))] + sequence_mask = torch.zeros(sequence_length, dtype=torch.bool) + sequence_mask[:len(neighbours)+1] = True + + traj = point_to_traj[(t,x,y)].copy() + traj.remove((t,x,y)) + sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))] + sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True + + seqs.append(sequence) + masks.append(sequence_mask) + + seqs = torch.tensor(seqs) + masks = torch.stack(masks) + res["traj{}".format(resolution)] = seqs + res["mask{}".format(resolution)] = masks + return res + + +@torch.no_grad() +def sample_trajectories_new(video_path, device,height,width): + from torchvision.models.optical_flow import Raft_Large_Weights + from torchvision.models.optical_flow import raft_large + + weights = Raft_Large_Weights.DEFAULT + transforms = weights.transforms() + + frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW") + + clips = list(range(len(frames))) + + #=============== GM-flow estimate forward optical flow============# + # model = GMFlow(feature_channels=128, + # num_scales=1, + # upsample_factor=8, + # num_head=1, + # attention_type='swin', + # ffn_dim_expansion=4, + # num_transformer_layers=6, + # ).to(device) + # checkpoint = torch.load('/home/xianyang/Data/code/FRESCO/model/gmflow_sintel-0c07dcb3.pth', map_location=lambda storage, loc: storage) + # weights = checkpoint['model'] if 'model' in checkpoint else checkpoint + # model.load_state_dict(weights, strict=False) + # model.eval() + # finished_trajectories = [] + + # current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms, height,width) + # results_dict = model(current_frames.to(device), next_frames.to(device), attn_splits_list=[2], + # corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True) + # flow_pr = results_dict['flow_preds'][-1] # [2*B, 2, H, W] + # fwd_flows, bwd_flows = flow_pr.chunk(2) # [B, 2, H, W] + # predicted_flows = fwd_flows + #=============== GM-flow estimate forward optical flow============# + + #=============== raft-large estimate forward optical flow============# + model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) + model = model.eval() + finished_trajectories = [] + + current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms, height,width) + list_of_flows = model(current_frames.to(device), next_frames.to(device)) + predicted_flows = list_of_flows[-1] + #=============== raft-large estimate forward optical flow============# + + predicted_flows = predicted_flows/max(height,width) + + resolutions =[(height//8,width//8),(height//16,width//16),(height//32,width//32),(height//64,width//64)] + #resolutions = [64, 32, 16, 8] + res = {} + window_sizes = {(height//8,width//8): 2, + (height//16,width//16): 1, + (height//32,width//32): 1, + (height//64,width//64): 1} + + for resolution in resolutions: + print("="*30) + # print(resolution) + # print('window_sizes[resolution]',window_sizes[resolution]) + trajectories = {} + height_scale_factor = resolution[0] / height + width_scale_factor = resolution[1] / width + predicted_flow_resolu = torch.round(max(resolution[0], resolution[1])*torch.nn.functional.interpolate(predicted_flows, scale_factor=(height_scale_factor, width_scale_factor))) + + T = predicted_flow_resolu.shape[0]+1 + H = predicted_flow_resolu.shape[2] + W = predicted_flow_resolu.shape[3] + + is_activated = torch.zeros([T, H, W], dtype=torch.bool) + + for t in range(T-1): + flow = predicted_flow_resolu[t] + for h in range(H): + for w in range(W): + + if not is_activated[t, h, w]: + is_activated[t, h, w] = True + # this point has not been traversed, start new trajectory + x = h + int(flow[1, h, w]) + y = w + int(flow[0, h, w]) + if x >= 0 and x < H and y >= 0 and y < W: + # trajectories.append([(t, h, w), (t+1, x, y)]) + trajectories[(t, h, w)]= (t+1, x, y) + + conflict_points = keys_with_same_value(trajectories) + for k in conflict_points: + index_to_pop = random.randint(0, len(conflict_points[k]) - 1) + conflict_points[k].pop(index_to_pop) + for point in conflict_points[k]: + if point[0] != T-1: + trajectories[point]= (-1, -1, -1) # stupid padding with (-1, -1, -1) + + active_traj = [] + all_traj = [] + for t in range(T): + pixel_set = {(t, x//H, x%H):0 for x in range(H*W)} + new_active_traj = [] + for traj in active_traj: + if traj[-1] in trajectories: + v = trajectories[traj[-1]] + new_active_traj.append(traj + [v]) + pixel_set[v] = 1 + else: + all_traj.append(traj) + active_traj = new_active_traj + active_traj+=[[pixel] for pixel in pixel_set if pixel_set[pixel] == 0] + all_traj += active_traj + + useful_traj = [i for i in all_traj if len(i)>1] + for idx in range(len(useful_traj)): + if useful_traj[idx][-1] == (-1, -1, -1): + useful_traj[idx] = useful_traj[idx][:-1] + print("how many points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj])) + print("how many points in the video for resolution{}?".format(resolution), T*H*W) + + # validate if there are no duplicates in the trajectories + trajs = [] + for traj in useful_traj: + trajs = trajs + traj + assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories." + + # check if non-appearing points + appearing points = all the points in the video + all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)]) + left_points = all_points- set(trajs) + print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points)) + for p in list(left_points): + useful_traj.append([p]) + print("how many points in all trajectories for resolution{} after pending?".format(resolution), sum([len(i) for i in useful_traj])) + + + longest_length = max([len(i) for i in useful_traj]) + sequence_length = (window_sizes[resolution]*2+1)**2 + longest_length - 1 + + seqs = [] + masks = [] + + # create a dictionary to facilitate checking the trajectories to which each point belongs. + point_to_traj = {} + for traj in useful_traj: + for p in traj: + point_to_traj[p] = traj + + for t in range(T): + for x in range(H): + for y in range(W): + neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W) + sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))] + sequence_mask = torch.zeros(sequence_length, dtype=torch.bool) + sequence_mask[:len(neighbours)+1] = True + + traj = point_to_traj[(t,x,y)].copy() + traj.remove((t,x,y)) + sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))] + sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True + + seqs.append(sequence) + masks.append(sequence_mask) + + seqs = torch.tensor(seqs) + masks = torch.stack(masks) + res["traj{}".format(resolution[0])] = seqs + res["mask{}".format(resolution[0])] = masks + return res \ No newline at end of file diff --git a/video_diffusion/common/instantiate_from_config.py b/video_diffusion/common/instantiate_from_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9c410d1ba6f0073fada0bbdb056cbad4abed3aa9 --- /dev/null +++ b/video_diffusion/common/instantiate_from_config.py @@ -0,0 +1,33 @@ +""" +Copy from stable diffusion +""" +import importlib + + +def instantiate_from_config(config:dict, **args_from_code): + """Util funciton to decompose differenct modules using config + + Args: + config (dict): with key of "target" and "params", better from yaml + static + args_from_code: additional con + + + Returns: + a validation/training pipeline, a module + """ + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/video_diffusion/common/logger.py b/video_diffusion/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ed344e5f4d377540e96fbd6dc00f1d9edc7201dd --- /dev/null +++ b/video_diffusion/common/logger.py @@ -0,0 +1,17 @@ +import os +import logging, logging.handlers +from accelerate.logging import get_logger + +def get_logger_config_path(logdir): + # accelerate handles the logger in multiprocessing + logger = get_logger(__name__) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s:%(levelname)s : %(message)s', + datefmt='%a, %d %b %Y %H:%M:%S', + filename=os.path.join(logdir, 'log.log'), + filemode='w') + chlr = logging.StreamHandler() + chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s')) + logger.logger.addHandler(chlr) + return logger \ No newline at end of file diff --git a/video_diffusion/common/set_seed.py b/video_diffusion/common/set_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..8f30dbf3028fc884adcd3ed0ffb317f2220ac32a --- /dev/null +++ b/video_diffusion/common/set_seed.py @@ -0,0 +1,28 @@ +import os +os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + +import torch +import numpy as np +import random + +from accelerate.utils import set_seed + + +def video_set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. + """ + set_seed(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True, warn_only=True) + # [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic) + diff --git a/video_diffusion/common/util.py b/video_diffusion/common/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b393ba6745e6737d1476626e95b3a40137e2982d --- /dev/null +++ b/video_diffusion/common/util.py @@ -0,0 +1,73 @@ +import os +import sys +import copy +import inspect +import datetime +from typing import List, Tuple, Optional, Dict + + +def glob_files( + root_path: str, + extensions: Tuple[str], + recursive: bool = True, + skip_hidden_directories: bool = True, + max_directories: Optional[int] = None, + max_files: Optional[int] = None, + relative_path: bool = False, +) -> Tuple[List[str], bool, bool]: + """glob files with specified extensions + + Args: + root_path (str): _description_ + extensions (Tuple[str]): _description_ + recursive (bool, optional): _description_. Defaults to True. + skip_hidden_directories (bool, optional): _description_. Defaults to True. + max_directories (Optional[int], optional): max number of directories to search. Defaults to None. + max_files (Optional[int], optional): max file number limit. Defaults to None. + relative_path (bool, optional): _description_. Defaults to False. + + Returns: + Tuple[List[str], bool, bool]: _description_ + """ + paths = [] + hit_max_directories = False + hit_max_files = False + for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)): + if skip_hidden_directories and os.path.basename(directory).startswith("."): + continue + + if max_directories is not None and directory_idx >= max_directories: + hit_max_directories = True + break + + paths += [ + os.path.join(directory, fname) + for fname in sorted(fnames) + if fname.lower().endswith(extensions) + ] + + if not recursive: + break + + if max_files is not None and len(paths) > max_files: + hit_max_files = True + paths = paths[:max_files] + break + + if relative_path: + paths = [os.path.relpath(p, root_path) for p in paths] + + return paths, hit_max_directories, hit_max_files + + +def get_time_string() -> str: + x = datetime.datetime.now() + return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}" + + +def get_function_args() -> Dict: + frame = sys._getframe(1) + args, _, _, values = inspect.getargvalues(frame) + args_dict = copy.deepcopy({arg: values[arg] for arg in args}) + + return args_dict diff --git a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10d612bfc6e5092ac5468840724a47cc87647f1b Binary files /dev/null and b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc differ diff --git a/video_diffusion/data/__pycache__/dataset.cpython-38.pyc b/video_diffusion/data/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48841f0408fcad6f853c1b5c23193bc2c68171f2 Binary files /dev/null and b/video_diffusion/data/__pycache__/dataset.cpython-38.pyc differ diff --git a/video_diffusion/data/__pycache__/dataset.cpython-39.pyc b/video_diffusion/data/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64500041e2bf96faf82ee7122b4d6b8abe5d2b03 Binary files /dev/null and b/video_diffusion/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/video_diffusion/data/__pycache__/transform.cpython-310.pyc b/video_diffusion/data/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efaa2b720abe6b780bbfd6f44fd9c04e25c3d39 Binary files /dev/null and b/video_diffusion/data/__pycache__/transform.cpython-310.pyc differ diff --git a/video_diffusion/data/__pycache__/transform.cpython-38.pyc b/video_diffusion/data/__pycache__/transform.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ce251796c67861c57f7f8c4d7ac7cc9a9724030 Binary files /dev/null and b/video_diffusion/data/__pycache__/transform.cpython-38.pyc differ diff --git a/video_diffusion/data/__pycache__/transform.cpython-39.pyc b/video_diffusion/data/__pycache__/transform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c0c8704931bb4dbf900fbeba35c17460789c96e Binary files /dev/null and b/video_diffusion/data/__pycache__/transform.cpython-39.pyc differ diff --git a/video_diffusion/data/dataset.py b/video_diffusion/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..965961c4846d91578c108985092c9a30da50fab2 --- /dev/null +++ b/video_diffusion/data/dataset.py @@ -0,0 +1,195 @@ +import os + +import numpy as np +from PIL import Image +from einops import rearrange +from pathlib import Path + +import torch +from torch.utils.data import Dataset + +from .transform import short_size_scale, random_crop, center_crop, offset_crop +from ..common.image_util import IMAGE_EXTENSION +import cv2 + +class ImageSequenceDataset(Dataset): + def __init__( + self, + path: str, + layout_mask_dir: str, + layout_mask_order: list, + prompt_ids: torch.Tensor, + prompt: str, + start_sample_frame: int=0, + n_sample_frame: int = 8, + sampling_rate: int = 1, + stride: int = -1, # only used during tuning to sample a long video + image_mode: str = "RGB", + image_size: int = 512, + crop: str = "center", + + class_data_root: str = None, + class_prompt_ids: torch.Tensor = None, + + offset: dict = { + "left": 0, + "right": 0, + "top": 0, + "bottom": 0 + }, + **args + + ): + self.path = path + self.images = self.get_image_list(path) + # + self.layout_mask_dir = layout_mask_dir + self.layout_mask_order = list(layout_mask_order) + + layout_mask_dir0 = os.path.join(self.layout_mask_dir,self.layout_mask_order[0]) + self.masks_index = self.get_image_list(layout_mask_dir0) + + # + self.n_images = len(self.images) + self.offset = offset + self.start_sample_frame = start_sample_frame + if n_sample_frame < 0: + n_sample_frame = len(self.images) + self.n_sample_frame = n_sample_frame + # local sampling rate from the video + self.sampling_rate = sampling_rate + + self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1 + if self.n_images < self.sequence_length: + raise ValueError(f"self.n_images {self.n_images } < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images }") + + # During tuning if video is too long, we sample the long video every self.stride globally + self.stride = stride if stride > 0 else (self.n_images+1) + self.video_len = (self.n_images - self.sequence_length) // self.stride + 1 + + self.image_mode = image_mode + self.image_size = image_size + crop_methods = { + "center": center_crop, + "random": random_crop, + } + if crop not in crop_methods: + raise ValueError + self.crop = crop_methods[crop] + + self.prompt = prompt + self.prompt_ids = prompt_ids + # Negative prompt for regularization to avoid overfitting during one-shot tuning + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_images_path = sorted(list(self.class_data_root.iterdir())) + self.num_class_images = len(self.class_images_path) + self.class_prompt_ids = class_prompt_ids + + + def __len__(self): + max_len = (self.n_images - self.sequence_length) // self.stride + 1 + + if hasattr(self, 'num_class_images'): + max_len = max(max_len, self.num_class_images) + + return max_len + + def __getitem__(self, index): + return_batch = {} + frame_indices = self.get_frame_indices(index%self.video_len) + frames = [self.load_frame(i) for i in frame_indices] + frames = self.transform(frames) + + layout_ = [] + for layout_name in self.layout_mask_order: + frame_indices = self.get_frame_indices(index%self.video_len) + layout_mask_dir = os.path.join(self.layout_mask_dir,layout_name) + mask = [self._read_mask(layout_mask_dir,i) for i in frame_indices] + masks = np.stack(mask) + layout_.append(masks) + layout_ = np.stack(layout_) + merged_masks = [] + for i in range(int(self.n_sample_frame)): + merged_mask_frame = np.sum(layout_[:,i,:,:,:], axis=0) + merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8) + merged_masks.append(merged_mask_frame) + masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w") + masks = torch.from_numpy(masks).half() + + layouts = rearrange(layout_,"s f c h w -> f s c h w" ) + layouts = torch.from_numpy(layouts).half() + + return_batch.update( + { + "images": frames, + "masks":masks, + "layouts":layouts, + "prompt_ids": self.prompt_ids, + } + ) + + if hasattr(self, 'class_data_root'): + class_index = index % (self.num_class_images - self.n_sample_frame) + class_indices = self.get_class_indices(class_index) + frames = [self.load_class_frame(i) for i in class_indices] + return_batch["class_images"] = self.tensorize_frames(frames) + return_batch["class_prompt_ids"] = self.class_prompt_ids + return return_batch + + def transform(self, frames): + frames = self.tensorize_frames(frames) + frames = offset_crop(frames, **self.offset) + frames = short_size_scale(frames, size=self.image_size) + frames = self.crop(frames, height=self.image_size, width=self.image_size) + return frames + + @staticmethod + def tensorize_frames(frames): + frames = rearrange(np.stack(frames), "f h w c -> c f h w") + return torch.from_numpy(frames).div(255) * 2 - 1 + + def _read_mask(self, mask_path,index: int): + ### read mask by pil + + mask_path = os.path.join(mask_path,f"{index:05d}.png") + + ### read mask by cv2 + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask = (mask > 0).astype(np.uint8) + # Determine dynamic destination size + height, width = mask.shape + dest_size = (width // 8, height // 8) + # Resize using nearest neighbor interpolation + mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) #cv2.INTER_CUBIC + mask = mask[np.newaxis, ...] + + return mask + + + def load_frame(self, index): + image_path = os.path.join(self.path, self.images[index]) + return Image.open(image_path).convert(self.image_mode) + + def load_class_frame(self, index): + image_path = self.class_images_path[index] + return Image.open(image_path).convert(self.image_mode) + + def get_frame_indices(self, index): + if self.start_sample_frame is not None: + frame_start = self.start_sample_frame + self.stride * index + else: + frame_start = self.stride * index + return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame)) + + def get_class_indices(self, index): + frame_start = index + return (frame_start + i for i in range(self.n_sample_frame)) + + @staticmethod + def get_image_list(path): + images = [] + for file in sorted(os.listdir(path)): + if file.endswith(IMAGE_EXTENSION): + images.append(file) + return images diff --git a/video_diffusion/data/transform.py b/video_diffusion/data/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..043097b4b1eb0108ba1c5430ddd11702dfe9a9b6 --- /dev/null +++ b/video_diffusion/data/transform.py @@ -0,0 +1,48 @@ +import random + +import torch + + +def short_size_scale(images, size): + h, w = images.shape[-2:] + short, long = (h, w) if h < w else (w, h) + + scale = size / short + long_target = int(scale * long) + + target_size = (size, long_target) if h < w else (long_target, size) + + return torch.nn.functional.interpolate( + input=images, size=target_size, mode="bilinear", antialias=True + ) + + +def random_short_side_scale(images, size_min, size_max): + size = random.randint(size_min, size_max) + return short_size_scale(images, size) + + +def random_crop(images, height, width): + image_h, image_w = images.shape[-2:] + h_start = random.randint(0, image_h - height) + w_start = random.randint(0, image_w - width) + return images[:, :, h_start : h_start + height, w_start : w_start + width] + + +def center_crop(images, height, width): + # offset_crop(images, 0,0, 200, 0) + image_h, image_w = images.shape[-2:] + h_start = (image_h - height) // 2 + w_start = (image_w - width) // 2 + return images[:, :, h_start : h_start + height, w_start : w_start + width] + +def offset_crop(image, left=0, right=0, top=200, bottom=0): + + n, c, h, w = image.shape + left = min(left, w-1) + right = min(right, w - left - 1) + top = min(top, h - 1) + bottom = min(bottom, h - top - 1) + image = image[:, :, top:h-bottom, left:w-right] + + return image \ No newline at end of file diff --git a/video_diffusion/models/__pycache__/attention.cpython-310.pyc b/video_diffusion/models/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca1f0f94143637be1804c7048b0f64b9368f7a04 Binary files /dev/null and b/video_diffusion/models/__pycache__/attention.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/attention.cpython-38.pyc b/video_diffusion/models/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07fbad7843b367d3f1d7ebab77fa9d91e1a8f1ce Binary files /dev/null and b/video_diffusion/models/__pycache__/attention.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/attention.cpython-39.pyc b/video_diffusion/models/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..669bdf0133fa46a9e194da4884b7e8da7ef4e22b Binary files /dev/null and b/video_diffusion/models/__pycache__/attention.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/attention_controlnet.cpython-310.pyc b/video_diffusion/models/__pycache__/attention_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bd0327f0c32d913f2948d6b83d385db29fe52a2 Binary files /dev/null and b/video_diffusion/models/__pycache__/attention_controlnet.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet3d.cpython-310.pyc b/video_diffusion/models/__pycache__/controlnet3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f28411307bb51bbe973164085b20f7836e6ba8 Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet3d.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet3d.cpython-38.pyc b/video_diffusion/models/__pycache__/controlnet3d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0209c8cf7eee07a453ea884476e0f863e936e8fa Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet3d.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet3d.cpython-39.pyc b/video_diffusion/models/__pycache__/controlnet3d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff6e2a61f0592f44373a9f71e809bd7ab2722004 Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet3d.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet_3d_blocks.cpython-310.pyc b/video_diffusion/models/__pycache__/controlnet_3d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aff9ce902162f4b30a5fb228e615335c63f663e1 Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet_3d_blocks.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet_attention.cpython-310.pyc b/video_diffusion/models/__pycache__/controlnet_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..689d5dbb1b0ea1e34b4ae1d30a34a3c5c162520c Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet_attention.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet_blocks.cpython-310.pyc b/video_diffusion/models/__pycache__/controlnet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e7144f65a6ce706dc04e038edd1830f9af8aa60 Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet_blocks.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/controlnet_unet_blocks.cpython-310.pyc b/video_diffusion/models/__pycache__/controlnet_unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f770b8155ba86eb5cf6b550637c412d40bb77670 Binary files /dev/null and b/video_diffusion/models/__pycache__/controlnet_unet_blocks.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/lora.cpython-310.pyc b/video_diffusion/models/__pycache__/lora.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..952d16c0aaea60f081669401ca1145f8fbc9fbbc Binary files /dev/null and b/video_diffusion/models/__pycache__/lora.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/lora.cpython-38.pyc b/video_diffusion/models/__pycache__/lora.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..625efa595834b589a8be461f212dbc43e6a08732 Binary files /dev/null and b/video_diffusion/models/__pycache__/lora.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/lora.cpython-39.pyc b/video_diffusion/models/__pycache__/lora.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a65ae70efca3c5d1b4d227bd35ab777b21062d47 Binary files /dev/null and b/video_diffusion/models/__pycache__/lora.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/motion_module.cpython-310.pyc b/video_diffusion/models/__pycache__/motion_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1400da0431f3e7382f69a2d74ed96b2f500d6090 Binary files /dev/null and b/video_diffusion/models/__pycache__/motion_module.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/resnet.cpython-310.pyc b/video_diffusion/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c96b923799d78f9ab084bd8d75ef5e7cb102763 Binary files /dev/null and b/video_diffusion/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/resnet.cpython-38.pyc b/video_diffusion/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc00a02cb623980ad9d382af554beb79e3fa00d5 Binary files /dev/null and b/video_diffusion/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/resnet.cpython-39.pyc b/video_diffusion/models/__pycache__/resnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef0bf58f651f8e1374a8dbce9be8289514dad487 Binary files /dev/null and b/video_diffusion/models/__pycache__/resnet.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/resnet_controlnet.cpython-310.pyc b/video_diffusion/models/__pycache__/resnet_controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f3afc60b28c1b1fab15a5b24d3ffbac50b459ff Binary files /dev/null and b/video_diffusion/models/__pycache__/resnet_controlnet.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/unet.cpython-310.pyc b/video_diffusion/models/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4d54813e8a86bfe80488f718cf9549e474953d6 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-310.pyc b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4699f2d699da41992da9b7487b93054969f90b28 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-38.pyc b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381b68eb30fb13f0319af3b942dc1e7a7cf46503 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-39.pyc b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..215cf64ec17d6e8cdbc709897e39a9a63d873807 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_condition.cpython-310.pyc b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7efbb4612d0ce40c23e10be6427a1897ab57a952 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_condition.cpython-38.pyc b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde7ece643268b8fe1f471c1291b2714b2ea97bd Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_condition.cpython-39.pyc b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb412c5387352ef38285375f413e732ec417539 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_condition.cpython-39.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-310.pyc b/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea2d20cb8020d7ca3a9fb551d594d98dc68f0efb Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-310.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-38.pyc b/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd72e407ece28042d8adb81ab683e34904a7b7d0 Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_3d_condition_from_tuneavideo.cpython-38.pyc differ diff --git a/video_diffusion/models/__pycache__/unet_blocks.cpython-310.pyc b/video_diffusion/models/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea050bfaa329fc35bb1830f8bf819d9381a4aace Binary files /dev/null and b/video_diffusion/models/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/video_diffusion/models/attention.py b/video_diffusion/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..16d8d00c6031bd08e5a2ee368963d4eb9ae4ece9 --- /dev/null +++ b/video_diffusion/models/attention.py @@ -0,0 +1,830 @@ +# code mostly taken from https://github.com/huggingface/diffusers +from dataclasses import dataclass +from typing import Optional, Callable +import math +import torch +from torch import nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.models.cross_attention import CrossAttention +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import numpy as np +import os +from PIL import Image +import glob + + +@dataclass +class SpatioTemporalTransformerModelOutput(BaseOutput): + """torch.FloatTensor of shape [batch x channel x frames x height x width]""" + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_temporal: bool = True, + model_config: dict = {}, + **transformer_kwargs, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + SpatioTemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config, + **transformer_kwargs, + ) + for d in range(num_layers) + ] + ) + + # Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward( + self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, **kwargs, + ): + # 1. Input + clip_length = None + is_video = hidden_states.ndim == 5 + if is_video: + clip_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0) + else: + # To adapt to classifier-free guidance where encoder_hidden_states=2 + batch_size = hidden_states.shape[0]//encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(batch_size, 0) + *_, h, w = hidden_states.shape + residual = hidden_states + + # check resolution + + height = hidden_states.shape[-2] + width = hidden_states.shape[-1] + trajs = {} + trajs["traj"] = kwargs["trajs"]["traj{}".format(height)] + trajs["mask"] = kwargs["trajs"]["mask{}".format(height)] + # trajs["t"] = kwargs["t"] + trajs["old_qk"] = kwargs["old_qk"] + trajs['flatten_res'] = kwargs['flatten_res'] + trajs['height'] = height + trajs['width'] = width + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") # (bf) (hw) c + else: + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # ([1, 77, 768] + timestep=timestep, + clip_length=clip_length, + **trajs, + ) + + # 3. Output + if not self.use_linear_projection: + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + output = hidden_states + residual + if is_video: + output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length) + + if not return_dict: + return (output,) + + return SpatioTemporalTransformerModelOutput(sample=output) + +import copy +class SpatioTemporalTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_sparse_causal_attention: bool = True, + use_temporal:bool = False, + temporal_attention_position: str = "after_feedforward", + model_config: dict = {} + ): + super().__init__() + self.use_temporal=use_temporal + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_sparse_causal_attention = use_sparse_causal_attention + # For safety, freeze the model_config + self.model_config = copy.deepcopy(model_config) + if 'least_sc_channel' in model_config: + if dim< model_config['least_sc_channel']: + self.model_config['SparseCausalAttention_index'] = [] + + self.temporal_attention_position = temporal_attention_position + temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"] + if temporal_attention_position not in temporal_attention_positions: + raise ValueError( + f"`temporal_attention_position` must be one of {temporal_attention_positions}" + ) + + # 1. Spatial-Attn + # spatial_attention = SparseCausalAttention if use_sparse_causal_attention else CrossAttention + # self.attn1 = spatial_attention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # is a self-attention + # Fully + self.attn1 = FullyFrameAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + else: + self.attn2 = None + self.norm2 = None + + # 3. Temporal-Attn + if use_temporal: + self.attn_temporal = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) # initialize as an identity function + self.norm_temporal = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + else: + self.attn_temporal=None + + # efficient_attention_backward_cutlass is not implemented for large channels + self.use_xformers = (dim <= 320) or "3090" not in torch.cuda.get_device_name(0) + + # 4. Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if use_memory_efficient_attention_xformers is True: + + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + else: + + pass + except Exception as e: + raise e + # self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temporal._use_memory_efficient_attention_xformers = ( + # use_memory_efficient_attention_xformers + # ), # FIXME: enabling this raises CUDA ERROR. Gotta dig in. + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + clip_length=None, + **kwargs, + ): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + # Update kwargs instead of creating a new dictionary + kwargs.update( + hidden_states=norm_hidden_states, + attention_mask=attention_mask, + ) + # kwargs = dict( + # hidden_states=norm_hidden_states, + # attention_mask=attention_mask, + # ) + if self.only_cross_attention: + kwargs.update(encoder_hidden_states=encoder_hidden_states) + if self.use_sparse_causal_attention: + kwargs.update(clip_length=clip_length) + if 'SparseCausalAttention_index' in self.model_config.keys(): + kwargs.update(SparseCausalAttention_index = self.model_config['SparseCausalAttention_index']) + + hidden_states = hidden_states + self.attn1(**kwargs) + + if clip_length is not None and self.temporal_attention_position == "after_spatial": + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + # print('hidden_states after 1 self attention',hidden_states.shape) + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + # print('norm_hidden_states',norm_hidden_states.shape) + hidden_states = ( + self.attn2( + norm_hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # [1, 77, 768] + attention_mask=attention_mask, + ) + + hidden_states + ) + + if clip_length is not None and self.temporal_attention_position == "after_cross" and self.attn_temporal is not None: + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # if clip_length is not None and self.temporal_attention_position == "after_feedforward" and self.attn_temporal is not None: + # hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + return hidden_states + + def apply_temporal_attention(self, hidden_states, timestep, clip_length): + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length) + norm_hidden_states = ( + self.norm_temporal(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temporal(hidden_states) + ) + hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + return hidden_states + + +def generate_dynamic_window_index(): + windows = [(0,5), (6,13), (14,21), (22,31), (32,39), (40,51), (52,58), (59,69)] + video_index = [] + + for start, end in windows: + for i in range(start, end + 1): + video_index.append(start) + + return video_index + +def generate_fix_anchor_frames(clip_length, interval): + # 创建锚点索引列表 + anchors = list(range(0, clip_length, interval)) + # 确保最后一个索引不超过clip_length + if anchors[-1] >= clip_length: + anchors[-1] = clip_length - 1 + # 生成每个锚点重复clip_length次的列表 + frame_index_list = [[anchor] * clip_length for anchor in anchors] + return frame_index_list + +def generate_random_anchor_frame_index(clip_length, interval): + # 计算每个间隔的起始帧索引 + interval_starts = torch.arange(0, clip_length, interval) + + # 初始化锚点帧列表 + frame_index_list = [] + + # 在每个间隔内随机选择一个帧索引 + for start in interval_starts: + # 确保随机索引不超过clip_length + end = min(start + interval, clip_length) + anchor_frame = torch.randint(start, end, (1,)).item() + + # 生成重复clip_length次的锚点帧索引列表 + frame_index = [anchor_frame] * clip_length + frame_index_list.append(frame_index) + + return frame_index_list + +class SparseCausalAttention(CrossAttention): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + SparseCausalAttention_index: list = [-1, 'first'] #list = ['anchor_interval8',0] #list = [0] #list = [-1, 'first','dynamic'] + ): + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + if clip_length is not None: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + + # *********************** Start of Spatial-temporal attention ********** + frame_index_list = [] + # print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}') + if len(SparseCausalAttention_index) > 0: + for index in SparseCausalAttention_index: + if isinstance(index, str): + if index == 'first': + frame_index = [0] * clip_length + if index == 'last': + frame_index = [clip_length-1] * clip_length + if (index == 'mid') or (index == 'middle'): + frame_index = [int(clip_length-1)//2] * clip_length + if index == "dynamic": + frame_index = generate_dynamic_window_index() + if index == "anchor_interval8": + frame_index = generate_fix_anchor_frames(clip_length,8) + + else: + assert isinstance(index, int), 'relative index must be int' + frame_index = torch.arange(clip_length) + index + frame_index = frame_index.clip(0, clip_length-1) + + if isinstance(frame_index[0], list): + frame_index_list = frame_index + else: + frame_index_list.append(frame_index) + # print("frame_index_list",frame_index_list) + # print("key before concat",key.shape) #[1, frame, 4096, 320] + key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list + ], dim=2) + value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list + ], dim=2) + + # *********************** End of Spatial-temporal attention ********** + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, hidden_states.shape[1], dim, attention_mask + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + +class FullyFrameAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = heads + # self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + self.q = None + self.inject_q = None + self.k = None + self.inject_k = None + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + return hidden_states + + def reshape_heads_to_batch_dim3(self, tensor): + batch_size1, batch_size2, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 3, 1, 2, 4) + return tensor + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, clip_length=None, inter_frame=False,**kwargs): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + # h = w = int(math.sqrt(sequence_length)) + h = kwargs['height'] + w = kwargs['width'] + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + + query = self.to_q(hidden_states) # (bf) x d(hw) x c + + self.q = query + if self.inject_q is not None: + query = self.inject_q + dim = query.shape[-1] + query_old = query.clone() + dim = query.shape[-1] + + # All frames + query = rearrange(query, "(b f) d c -> b (f d) c", f=clip_length) + + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + self.k = key + if self.inject_k is not None: + key = self.inject_k + key_old = key.clone() + value = self.to_v(encoder_hidden_states) + + if inter_frame: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + key = rearrange(key, "b f d c -> b (f d) c",) + value = rearrange(value, "b f d c -> b (f d) c") + else: + # All frames + key = rearrange(key, "(b f) d c -> b (f d) c", f=clip_length) + value = rearrange(value, "(b f) d c -> b (f d) c", f=clip_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + + if [h,w] in kwargs['flatten_res']: + + hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length) + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if kwargs["old_qk"] == 1: + query = query_old + key = key_old + else: + query = hidden_states + key = hidden_states + value = hidden_states + + traj = kwargs["traj"] + traj = rearrange(traj, '(f n) l d -> f n l d', f=clip_length, n=sequence_length) + mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=clip_length, n=sequence_length) + mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -clip_length+1:]], dim=-1) + + # print('traj',traj.shape) + # print('mask',mask.shape) + + traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -clip_length+1:, :]], dim=-2) + t_inds = traj_key_sequence_inds[:, :, :, 0] + x_inds = traj_key_sequence_inds[:, :, :, 1] + y_inds = traj_key_sequence_inds[:, :, :, 2] + + query_tempo = query.unsqueeze(-2) + _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + key_tempo = _key[:, t_inds, x_inds, y_inds] + value_tempo = _value[:, t_inds, x_inds, y_inds] + key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') + value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') + + mask = rearrange(torch.stack([mask, mask]), 'b f n l -> (b f) n l') + mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2) + attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype) # regular zeros_like + attn_bias[~mask] = -torch.inf + + # print('attn_bias',attn_bias.shape) + # print('query_tempo',query_tempo.shape) + # print('key_tempo',key_tempo.shape) + + # flow attention + query_tempo = self.reshape_heads_to_batch_dim3(query_tempo) + key_tempo = self.reshape_heads_to_batch_dim3(key_tempo) + value_tempo = self.reshape_heads_to_batch_dim3(value_tempo) + + attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias + attn_matrix2 = F.softmax(attn_matrix2, dim=-1) + out = (attn_matrix2@value_tempo).squeeze(-2) + + hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + # All frames + hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length) + return hidden_states + + diff --git a/video_diffusion/models/controlnet3d.py b/video_diffusion/models/controlnet3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9f36fb7c40a7c14ca5ca4083728346c5001707c8 --- /dev/null +++ b/video_diffusion/models/controlnet3d.py @@ -0,0 +1,510 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os +import json +import torch +from torch import nn +from torch.nn import functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers import ModelMixin +import copy +# from .unet_3d_blocks import ( +# CrossAttnDownBlock3D, +# DownBlock3D, +# UNetMidBlock3DCrossAttn, +# get_down_block, +# ) + +from .controlnet_unet_blocks import get_down_block +from .controlnet_unet_blocks import CrossAttnDownBlockPseudo3D as CrossAttnDownBlock3D +from .controlnet_unet_blocks import DownBlockPseudo3D as DownBlock3D +from .controlnet_unet_blocks import UNetMidBlockPseudo3DCrossAttn as UNetMidBlock3DCrossAttn + +from .resnet import InflatedConv3d + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "DownBlockPseudo3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + **kwargs + ): + super().__init__() + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + kwargs_copy=copy.deepcopy(kwargs) + # temporal_downsample_i = ((i >= (len(down_block_types)-self.temporal_downsample_time)) + # and (not is_final_block)) + # kwargs_copy.update({'temporal_downsample': temporal_downsample_i} ) + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=False, + model_config=kwargs_copy + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_temporal=False, + model_config=kwargs + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, control_temporal_idx=None, control_mid_temporal=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "DownBlockPseudo3D", + ] + # config["control_temporal_idx"] = control_temporal_idx + # config["control_mid_temporal"] = control_mid_temporal + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config) + + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + + state_dict = torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + if '_temporal.' in k: + if 'conv' in k: + state_dict.update({k: v}) + else: + copyk = k + copyk = copyk.replace('_temporal.', '1.') + state_dict.update({k: state_dict[copyk]}) + + # for key in state_dict.keys(): + # print(key) + # exit() + + + model.load_state_dict(state_dict) + + return model + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/video_diffusion/models/controlnet_attention.py b/video_diffusion/models/controlnet_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8e19e019bbefc6700315b5ede825a20bc45f8d92 --- /dev/null +++ b/video_diffusion/models/controlnet_attention.py @@ -0,0 +1,911 @@ +# code mostly taken from https://github.com/huggingface/diffusers +from dataclasses import dataclass +from typing import Optional, Callable + +import torch +from torch import nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.models.cross_attention import CrossAttention +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import numpy as np +import os +from PIL import Image +import glob + + +@dataclass +class SpatioTemporalTransformerModelOutput(BaseOutput): + """torch.FloatTensor of shape [batch x channel x frames x height x width]""" + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_temporal: bool = True, + model_config: dict = {}, + **transformer_kwargs, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + SpatioTemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config, + **transformer_kwargs, + ) + for d in range(num_layers) + ] + ) + + # Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward( + self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True + ): + # 1. Input + clip_length = None + is_video = hidden_states.ndim == 5 + if is_video: + clip_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0) + else: + # To adapt to classifier-free guidance where encoder_hidden_states=2 + batch_size = hidden_states.shape[0]//encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(batch_size, 0) + *_, h, w = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") # (bf) (hw) c + else: + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # ([1, 77, 768] + timestep=timestep, + clip_length=clip_length, + ) + + # 3. Output + if not self.use_linear_projection: + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + output = hidden_states + residual + if is_video: + output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length) + + if not return_dict: + return (output,) + + return SpatioTemporalTransformerModelOutput(sample=output) + +import copy +class SpatioTemporalTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_sparse_causal_attention: bool = True, + use_temporal:bool = False, + temporal_attention_position: str = "after_feedforward", + model_config: dict = {} + ): + super().__init__() + self.use_temporal=use_temporal + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_sparse_causal_attention = use_sparse_causal_attention + # For safety, freeze the model_config + self.model_config = copy.deepcopy(model_config) + if 'least_sc_channel' in model_config: + if dim< model_config['least_sc_channel']: + self.model_config['SparseCausalAttention_index'] = [] + + self.temporal_attention_position = temporal_attention_position + temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"] + if temporal_attention_position not in temporal_attention_positions: + raise ValueError( + f"`temporal_attention_position` must be one of {temporal_attention_positions}" + ) + + # 1. Spatial-Attn + # spatial_attention = SparseCausalAttention if use_sparse_causal_attention else CrossAttention + # self.attn1 = spatial_attention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # ) + # is a self-attention + # Fully + self.attn1 = IndividualAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + else: + self.attn2 = None + self.norm2 = None + + # 3. Temporal-Attn + if use_temporal: + self.attn_temporal = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) # initialize as an identity function + self.norm_temporal = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + else: + self.attn_temporal=None + + # efficient_attention_backward_cutlass is not implemented for large channels + self.use_xformers = (dim <= 320) or "3090" not in torch.cuda.get_device_name(0) + + # 4. Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op: Optional[Callable] = None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if use_memory_efficient_attention_xformers is True: + + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + else: + + pass + except Exception as e: + raise e + # self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temporal._use_memory_efficient_attention_xformers = ( + # use_memory_efficient_attention_xformers + # ), # FIXME: enabling this raises CUDA ERROR. Gotta dig in. + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + clip_length=None, + ): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + kwargs = dict( + hidden_states=norm_hidden_states, + attention_mask=attention_mask, + ) + if self.only_cross_attention: + kwargs.update(encoder_hidden_states=encoder_hidden_states) + if self.use_sparse_causal_attention: + kwargs.update(clip_length=clip_length) + if 'SparseCausalAttention_index' in self.model_config.keys(): + kwargs.update(SparseCausalAttention_index = self.model_config['SparseCausalAttention_index']) + + hidden_states = hidden_states + self.attn1(**kwargs) + + if clip_length is not None and self.temporal_attention_position == "after_spatial": + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + # print('hidden_states after 1 self attention',hidden_states.shape) + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + # print('norm_hidden_states',norm_hidden_states.shape) + hidden_states = ( + self.attn2( + norm_hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # [1, 77, 768] + attention_mask=attention_mask, + ) + + hidden_states + ) + + if clip_length is not None and self.temporal_attention_position == "after_cross" and self.attn_temporal is not None: + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # if clip_length is not None and self.temporal_attention_position == "after_feedforward" and self.attn_temporal is not None: + # hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + return hidden_states + + def apply_temporal_attention(self, hidden_states, timestep, clip_length): + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length) + norm_hidden_states = ( + self.norm_temporal(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temporal(hidden_states) + ) + hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + return hidden_states + +class IndividualAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, clip_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # (bf) x d(hw) x c + dim = query.shape[-1] + + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + curr_frame_index = torch.arange(clip_length) + + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + + key = key[:, curr_frame_index] + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + value = value[:, curr_frame_index] + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + SparseCausalAttention_index: list = [-1, 'first'] #list = ['anchor_interval8',0] #list = [0] #list = [-1, 'first','dynamic'] + ): + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + if clip_length is not None: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + + # *********************** Start of Spatial-temporal attention ********** + frame_index_list = [] + # print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}') + if len(SparseCausalAttention_index) > 0: + for index in SparseCausalAttention_index: + if isinstance(index, str): + if index == 'first': + frame_index = [0] * clip_length + if index == 'last': + frame_index = [clip_length-1] * clip_length + if (index == 'mid') or (index == 'middle'): + frame_index = [int(clip_length-1)//2] * clip_length + if index == "dynamic": + frame_index = generate_dynamic_window_index() + if index == "anchor_interval8": + frame_index = generate_fix_anchor_frames(clip_length,8) + + else: + assert isinstance(index, int), 'relative index must be int' + frame_index = torch.arange(clip_length) + index + frame_index = frame_index.clip(0, clip_length-1) + + if isinstance(frame_index[0], list): + frame_index_list = frame_index + else: + frame_index_list.append(frame_index) + # print("frame_index_list",frame_index_list) + # print("key before concat",key.shape) #[1, frame, 4096, 320] + key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list + ], dim=2) + value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list + ], dim=2) + + # *********************** End of Spatial-temporal attention ********** + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, hidden_states.shape[1], dim, attention_mask + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + +class FullyFrameAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = heads + # self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, clip_length=None, inter_frame=False): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # (bf) x d(hw) x c + dim = query.shape[-1] + + # All frames + query = rearrange(query, "(b f) d c -> b (f d) c", f=clip_length) + + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if inter_frame: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + key = rearrange(key, "b f d c -> b (f d) c",) + value = rearrange(value, "b f d c -> b (f d) c") + else: + # All frames + key = rearrange(key, "(b f) d c -> b (f d) c", f=clip_length) + value = rearrange(value, "(b f) d c -> b (f d) c", f=clip_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + # All frames + hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length) + return hidden_states + diff --git a/video_diffusion/models/controlnet_unet_blocks.py b/video_diffusion/models/controlnet_unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6fcc3d3754e8afcb7bfd55406452afc13c1a8e --- /dev/null +++ b/video_diffusion/models/controlnet_unet_blocks.py @@ -0,0 +1,641 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import torch +from torch import nn + +from .controlnet_attention import SpatioTemporalTransformerModel +from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_temporal=True, + model_config: dict={} +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockPseudo3D": + return DownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif down_block_type == "CrossAttnDownBlockPseudo3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockPseudo3D") + return CrossAttnDownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=use_temporal, + model_config=model_config + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_temporal=True, + model_config: dict={} +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockPseudo3D": + return UpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif up_block_type == "CrossAttnUpBlockPseudo3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockPseudo3D") + return CrossAttnUpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=use_temporal, + model_config=model_config + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockPseudo3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_temporal=False, + model_config: dict={} + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_temporal = use_temporal, + model_config=model_config + ) + ) + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention_mask is currently not used. Implement once used + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_temporal=True, + model_config: dict={} + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + # model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + model_config: dict={} + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + # model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_temporal=True, + model_config: dict={}, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.model_config = model_config + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + model_config: dict={}, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/video_diffusion/models/lora.py b/video_diffusion/models/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae65673d6b3012da4f551b8d61a2e9ceb38aa24 --- /dev/null +++ b/video_diffusion/models/lora.py @@ -0,0 +1,131 @@ +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +# from diffusers.utils +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.models.cross_attention import CrossAttention + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, stride=1): + super().__init__() + + if rank > min(in_features, out_features): + Warning(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}, reset to {min(in_features, out_features)//2}") + rank = min(in_features, out_features)//2 + + + self.down = nn.Conv1d(in_features, rank, bias=False, + kernel_size=3, + stride = stride, + padding=1,) + self.up = nn.Conv1d(rank, out_features, bias=False, + kernel_size=3, + padding=1,) + + nn.init.normal_(self.down.weight, std=1 / rank) + # nn.init.zeros_(self.down.bias.data) + + nn.init.zeros_(self.up.weight) + # nn.init.zeros_(self.up.bias.data) + if stride > 1: + self.skip = nn.AvgPool1d(kernel_size=3, stride=2, padding=1) + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + if hasattr(self, 'skip'): + hidden_states=self.skip(hidden_states) + return up_hidden_states.to(orig_dtype)+hidden_states + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states diff --git a/video_diffusion/models/resnet.py b/video_diffusion/models/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..c63855f2fff6f0af6a656d76eac8cd35dcbd8b25 --- /dev/null +++ b/video_diffusion/models/resnet.py @@ -0,0 +1,537 @@ +# code mostly taken from https://github.com/huggingface/diffusers +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy + +from einops import rearrange +from .lora import LoRALinearLayer, LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class PseudoConv3d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, temporal_kernel_size=None, model_config: dict={}, temporal_downsample=False, **kwargs): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + **kwargs, + ) + if temporal_kernel_size is None: + temporal_kernel_size = kernel_size + + if temporal_downsample is True: + temporal_stride = 2 + else: + temporal_stride = 1 + + + if 'lora' in model_config.keys() : + self.conv_temporal = ( + LoRALinearLayer( + out_channels, + out_channels, + rank=model_config['lora'], + stride=temporal_stride + + ) + if kernel_size > 1 + else None + ) + else: + self.conv_temporal = ( + nn.Conv1d( + out_channels, + out_channels, + kernel_size=temporal_kernel_size, + padding=temporal_kernel_size // 2, + ) + if kernel_size > 1 + else None + ) + + if self.conv_temporal is not None: + nn.init.dirac_(self.conv_temporal.weight.data) # initialized to be identity + nn.init.zeros_(self.conv_temporal.bias.data) + + def forward(self, x): + b = x.shape[0] + + is_video = x.ndim == 5 + if is_video: + x = rearrange(x, "b c f h w -> (b f) c h w") + + x = super().forward(x) + + if is_video: + x = rearrange(x, "(b f) c h w -> b c f h w", b=b) + + if self.conv_temporal is None or not is_video: + return x + + *_, h, w = x.shape + + x = rearrange(x, "b c f h w -> (b h w) c f") + + x = self.conv_temporal(x) + + x = rearrange(x, "(b h w) c f -> b c f h w", h=h, w=w) + + return x + + +class UpsamplePseudo3D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv", model_config: dict={}, **kwargs + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.model_config = copy.deepcopy(model_config) + + conv = None + if use_conv_transpose: + raise NotImplementedError + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + # Do NOT downsample in upsample block + td = False + + conv = PseudoConv3d(self.channels, self.out_channels, 3, padding=1, + model_config=model_config, temporal_downsample=td) + # conv = PseudoConv3d(self.channels, self.out_channels, 3, kwargs['lora'], padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + b = hidden_states.shape[0] + is_video = hidden_states.ndim == 5 + if is_video: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + if is_video: + td = ('temporal_downsample' in self.model_config and self.model_config['temporal_downsample'] is True) + + + if td: + hidden_states = rearrange(hidden_states, " (b f) c h w -> b c h w f ", b=b) + t_b, t_c, t_h, t_w, t_f = hidden_states.shape + hidden_states = rearrange(hidden_states, " b c h w f -> (b c) (h w) f ", b=b) + + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="linear") + hidden_states = rearrange(hidden_states, " (b c) (h w) f -> (b f) c h w ", b=t_b, h=t_h) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class DownsamplePseudo3D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, model_config: dict={}, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + self.model_config = copy.deepcopy(model_config) + # self.model_config = copy.deepcopy(model_config) + + if use_conv: + td = ('temporal_downsample' in model_config and model_config['temporal_downsample'] is True) + + conv = PseudoConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding, + model_config=model_config, temporal_downsample=td) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + if self.use_conv: + hidden_states = self.conv(hidden_states) + else: + b = hidden_states.shape[0] + is_video = hidden_states.ndim == 5 + if is_video: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = self.conv(hidden_states) + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + return hidden_states + + +class ResnetBlockPseudo3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + model_config: dict={}, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = PseudoConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, model_config=model_config) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = PseudoConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, model_config=model_config) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = UpsamplePseudo3D(in_channels, use_conv=False, model_config=model_config) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = DownsamplePseudo3D(in_channels, use_conv=False, padding=1, name="op", model_config=model_config) + + self.use_in_shortcut = ( + self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = PseudoConv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, model_config=model_config + ) + # save features + self.out_layers_features = None + self.out_layers_inject_features = None + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + is_video = hidden_states.ndim == 5 + if is_video: + b, c, f, h, w = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + temb = temb.repeat_interleave(f, 0) + + hidden_states = hidden_states + temb + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + is_video = hidden_states.ndim == 5 + if is_video: + b, c, f, h, w = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + temb = temb.repeat_interleave(f, 0) + + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + # save features + self.out_layers_features = hidden_states + if self.out_layers_inject_features is not None: + hidden_states = self.out_layers_inject_features + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output + + +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(tensor.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) \ No newline at end of file diff --git a/video_diffusion/models/unet_3d_blocks.py b/video_diffusion/models/unet_3d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..e697d1a8c7a0c178bb87310d9df99b620d1062d3 --- /dev/null +++ b/video_diffusion/models/unet_3d_blocks.py @@ -0,0 +1,642 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import torch +from torch import nn + +from .attention import SpatioTemporalTransformerModel +from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_temporal=True, + model_config: dict={} +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockPseudo3D": + return DownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif down_block_type == "CrossAttnDownBlockPseudo3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockPseudo3D") + return CrossAttnDownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=use_temporal, + model_config=model_config + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_temporal=True, + model_config: dict={} +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockPseudo3D": + return UpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif up_block_type == "CrossAttnUpBlockPseudo3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockPseudo3D") + return CrossAttnUpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=use_temporal, + model_config=model_config + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockPseudo3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_temporal=False, + model_config: dict={} + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_temporal = use_temporal, + model_config=model_config + ) + ) + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, **kwargs ): + # TODO(Patrick, William) - attention_mask is currently not used. Implement once used + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,**kwargs).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_temporal=True, + model_config: dict={} + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + # model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, **kwargs): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, **kwargs).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + model_config: dict={} + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + # model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_temporal=True, + model_config: dict={}, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.model_config = model_config + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm + # model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_temporal=use_temporal, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + **kwargs, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,**kwargs).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + model_config: dict={}, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + # model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/video_diffusion/models/unet_3d_condition.py b/video_diffusion/models/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..59a92a455280ff641af39df3ac29899c5480f824 --- /dev/null +++ b/video_diffusion/models/unet_3d_condition.py @@ -0,0 +1,555 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import os +import glob +import json +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import copy + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_3d_blocks import ( + CrossAttnDownBlockPseudo3D, + CrossAttnUpBlockPseudo3D, + DownBlockPseudo3D, + UNetMidBlockPseudo3DCrossAttn, + UpBlockPseudo3D, + get_down_block, + get_up_block, +) +from .resnet import PseudoConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetPseudo3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNetPseudo3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + "DownBlockPseudo3D", + ), + mid_block_type: str = "UNetMidBlockPseudo3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + **kwargs + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + if 'temporal_downsample' in kwargs and kwargs['temporal_downsample'] is True: + kwargs['temporal_downsample_time'] = 3 + self.temporal_downsample_time = kwargs.get('temporal_downsample_time', 0) + + # input + self.conv_in = PseudoConv3d(in_channels, block_out_channels[0], + kernel_size=3, padding=(1, 1), model_config=kwargs) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + kwargs_copy=copy.deepcopy(kwargs) + temporal_downsample_i = ((i >= (len(down_block_types)-self.temporal_downsample_time)) + and (not is_final_block)) + kwargs_copy.update({'temporal_downsample': temporal_downsample_i} ) + + # kwargs_copy.update({'SparseCausalAttention_index': temporal_downsample_i} ) + if temporal_downsample_i: + print(f'Initialize model temporal downsample at layer {i}') + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=True, + model_config=kwargs_copy + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockPseudo3DCrossAttn": + self.mid_block = UNetMidBlockPseudo3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_temporal=True, ##not sure check org fatezero + model_config=kwargs + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + kwargs_copy=copy.deepcopy(kwargs) + kwargs_copy.update({'temporal_downsample': + i < (self.temporal_downsample_time-1)}) + if i < (self.temporal_downsample_time-1): + print(f'Initialize model temporal updample at layer {i}') + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_temporal=True, + model_config=kwargs_copy + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = PseudoConv3d(block_out_channels[0], out_channels, + kernel_size=3, padding=1, model_config=kwargs) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, + (CrossAttnDownBlockPseudo3D, DownBlockPseudo3D, CrossAttnUpBlockPseudo3D, UpBlockPseudo3D), + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, # None + attention_mask: Optional[torch.Tensor] = None, # None + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + **kwargs, + ) -> Union[UNetPseudo3DConditionOutput, Tuple]: + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: # None + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: # False + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + new_down_block_res_samples += (down_block_res_sample + down_block_additional_residual,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + **kwargs, + ) + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # for i in down_block_res_samples: print(i.shape) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 8, 32, 32]) + # torch.Size([1, 640, 8, 32, 32]) + # torch.Size([1, 640, 8, 32, 32]) + # torch.Size([1, 640, 4, 16, 16]) + # torch.Size([1, 1280, 4, 16, 16]) + # torch.Size([1, 1280, 4, 16, 16]) + # torch.Size([1, 1280, 2, 8, 8]) + # torch.Size([1, 1280, 2, 8, 8]) + # torch.Size([1, 1280, 2, 8, 8]) + + # 5. up + # sample torch.Size([1, 1280, 37, 32, 32]) + # sample torch.Size([1, 640, 37, 64, 64]) + # sample torch.Size([1, 320, 37, 64, 64]) + + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + **kwargs, + ) + # if up_block_additional_residual is not None and sample.shape[-1] == 32: + # sample = sample + up_block_additional_residual.unsqueeze(0) ### dift embedding for key point + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return UNetPseudo3DConditionOutput(sample=sample) + + @classmethod + def from_2d_model(cls, model_path, model_config): + config_path = os.path.join(model_path, "config.json") + if not os.path.isfile(config_path): + raise RuntimeError(f"{config_path} does not exist") + with open(config_path, "r") as f: + config = json.load(f) + + config.pop("_class_name") + config.pop("_diffusers_version") + + block_replacer = { + "CrossAttnDownBlock2D": "CrossAttnDownBlockPseudo3D", + "DownBlock2D": "DownBlockPseudo3D", + "UpBlock2D": "UpBlockPseudo3D", + "CrossAttnUpBlock2D": "CrossAttnUpBlockPseudo3D", + } + + def convert_2d_to_3d_block(block): + return block_replacer[block] if block in block_replacer else block + + config["down_block_types"] = [ + convert_2d_to_3d_block(block) for block in config["down_block_types"] + ] + config["up_block_types"] = [convert_2d_to_3d_block(block) for block in config["up_block_types"]] + if model_config is not None: + config.update(model_config) + + model = cls(**config) + + state_dict_path_condidates = glob.glob(os.path.join(model_path, "*.bin")) + if state_dict_path_condidates: + state_dict = torch.load(state_dict_path_condidates[0], map_location="cpu") + model.load_2d_state_dict(state_dict=state_dict) + + return model + + def load_2d_state_dict(self, state_dict, **kwargs): + state_dict_3d = self.state_dict() + + + # for k,v in state_dict_3d.items(): + # print("new 3d model key:",k) + + # for k,v in state_dict.items(): + # print("org 2d model key:",k) + # exit() + for k, v in state_dict.items(): + if k not in state_dict_3d: + raise KeyError(f"2d state_dict key {k} does not exist in 3d model") + elif v.shape != state_dict_3d[k].shape: + raise ValueError(f"state_dict shape mismatch, 2d {v.shape}, 3d {state_dict_3d[k].shape}") + + for k, v in state_dict_3d.items(): + if "_temporal" in k: + continue + if k not in state_dict: + raise KeyError(f"3d state_dict key {k} does not exist in 2d model") + ### choice 1, init temporal attention with spatial attention weight + ''' + for k, v in state_dict_3d.items(): + if k not in state_dict: + if "_temporal" in k: + # org random init temporal attention + #continue + #print("init temporal attention with sd spatial attention weight") + if 'conv' in k: + ## may be continue + #state_dict_3d.update({k: v}) + continue + else: + copyk = k + copyk = copyk.replace('_temporal.', '1.') + state_dict_3d.update({k: state_dict[copyk]}) + + else: + raise KeyError(f"3d state_dict key {k} does not exist in 2d model") + ''' + #### end of choice 1############ + state_dict_3d.update(state_dict) + self.load_state_dict(state_dict_3d, **kwargs) \ No newline at end of file diff --git a/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-310.pyc b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e8f17a60f626de179c873870bcdebdb08cb4135 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-310.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-38.pyc b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1ea620213e105dfb022ae72ab6722aea07a8d9 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-38.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-39.pyc b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f87cdf376e41c660fd44b83b4eb66f2d0d7921e Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/ddim_spatial_temporal.cpython-39.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-310.pyc b/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93fcca2991c8df4a81b74cfd3a778adc8a8478a Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-310.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-38.pyc b/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4656eb78695ec34884d6684bbae7883d5196b35a Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/p2p_ddim_spatial_temporal.cpython-38.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-310.pyc b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc0536c26c731f62f8944959d5cccffedd7c4caf Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-310.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-38.pyc b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2c8f9c809ef234dc930d060eaabcf0037fa2a1 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-38.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-39.pyc b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0706441e4fc29e3d09dd461bb12b8d55cc86771 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/p2p_validation_loop.cpython-39.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-310.pyc b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c450ef2ebfee34482441abeb5470a522e313da11 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-310.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-38.pyc b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22c24f09b4b7de75a9715ac64ddf3bc2d58cb3e6 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-38.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-39.pyc b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be729f2a17b7159beebace8a0064352e1294dd3b Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/stable_diffusion.cpython-39.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/validation_loop.cpython-310.pyc b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23cde8ce8b90bbabecfabe7c813000374a7082c6 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-310.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/validation_loop.cpython-38.pyc b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c205d80bee168e1122f1f3dbe27e43f7abe6ca41 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-38.pyc differ diff --git a/video_diffusion/pipelines/__pycache__/validation_loop.cpython-39.pyc b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afe95896b853390e0797c520f15607c830783141 Binary files /dev/null and b/video_diffusion/pipelines/__pycache__/validation_loop.cpython-39.pyc differ diff --git a/video_diffusion/pipelines/ddim_spatial_temporal.py b/video_diffusion/pipelines/ddim_spatial_temporal.py new file mode 100755 index 0000000000000000000000000000000000000000..4469c5e5d565df30565371ebc8d6e964518eb073 --- /dev/null +++ b/video_diffusion/pipelines/ddim_spatial_temporal.py @@ -0,0 +1,833 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import inspect +from typing import Callable, List, Optional, Union +import PIL +import torch +import numpy as np +from einops import rearrange +from tqdm import tqdm + +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + + +from .stable_diffusion import SpatioTemporalStableDiffusionPipeline +from diffusers.models import AutoencoderKL +from transformers import CLIPTextModel, CLIPTokenizer + +import torch.nn.functional as F +from omegaconf import OmegaConf +from video_diffusion.prompt_attention.attention_register import register_attention_control +from video_diffusion.prompt_attention.attention_util import ModulatedAttentionControl,ModulatedAttention_ControlEdit,Attention_Record_Processor +from video_diffusion.prompt_attention import attention_util +from video_diffusion.prompt_attention.sd_study_utils import * +from video_diffusion.prompt_attention.attention_store import AttentionStore +from video_diffusion.common.image_util import save_gif_mp4_folder_type + +from PIL import Image +from einops import rearrange +from ..models.controlnet3d import ControlNetModel +from ..models.unet_3d_condition import UNetPseudo3DConditionModel + +from diffusers.schedulers import ( + DDIMScheduler, + DDIMInverseScheduler, +) +import os +import nltk +nltk.download('punkt') + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DDIMSpatioTemporalStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline): + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. + """ + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNetPseudo3DConditionModel, + controlnet: ControlNetModel, + scheduler: DDIMScheduler, + inverse_scheduler: DDIMInverseScheduler, + disk_store: bool=False, + logdir=None, + ): + super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler,inverse_scheduler) + self.store_controller = attention_util.AttentionStore(disk_store=disk_store) + self.logdir=logdir + + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. + """ + + + def check_inputs(self, prompt, height, width, callback_steps, strength=None): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if strength is not None: + if strength <= 0 or strength > 1: + raise ValueError(f"The value of strength should in (0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + + @torch.no_grad() + def prepare_source_latents(self, image, batch_size, num_images_per_prompt, + # dtype, device, + text_embeddings, + generator=None): + + # Not sure if image need to change device and type + # image = image.to(device=device, dtype=dtype) + print("generator is list:",isinstance(generator, list)) + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + ## org is + #init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = self.vae.encode(image).latent_dist.mean + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + # get latents + init_latents_bcfhw = rearrange(init_latents, "(b f) c h w -> b c f h w", b=batch_size) + return init_latents_bcfhw + + + def prepare_latents_ddim_inverted(self, image, batch_size, + source_prompt, + do_classifier_free_guidance, + control = None, + controlnet_conditioning_scale=None, + use_pnp=None, + **kwargs, + ): + weight_dtype = image.dtype + device = self._execution_device + print('device',device) + timesteps = self.scheduler.timesteps + saved_features0 = [] + saved_features1 = [] + saved_features2 = [] + saved_q4 = [] + saved_k4 = [] + saved_q5 = [] + saved_k5 = [] + saved_q6 = [] + saved_k6 = [] + saved_q7 = [] + saved_k7 = [] + saved_q8 = [] + saved_k8 = [] + saved_q9 = [] + saved_k9 = [] + #ddim inverse + num_inverse_steps = 50 + self.inverse_scheduler.set_timesteps(num_inverse_steps, device=device) + inverse_timesteps, num_inverse_steps = self.get_inverse_timesteps(num_inverse_steps, 1, device) + num_warmup_steps = len(inverse_timesteps) - num_inverse_steps * self.inverse_scheduler.order + + #============ddim inversion==========* + prompt_embeds = self._encode_prompt( + source_prompt, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=None, + ) + + latents = self.prepare_video_latents(image, batch_size, self.unet.dtype, device) + + bz, c, clip_length, downsample_height, downsample_width = latents.shape + del self.store_controller + self.store_controller = attention_util.AttentionStore() + attention_maps_list = [] + self_attention_maps_list = [] + cond_embeddings_list = [] + + editor = Attention_Record_Processor(additional_attention_store=self.store_controller) + attention_util.register_attention_control(self, editor, prompt_embeds, clip_length,downsample_height,downsample_width,ddim_inversion=True) + + + with self.progress_bar(total=num_inverse_steps-1) as progress_bar: + for i, t in enumerate(inverse_timesteps[1:]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + + down_block_res_samples, mid_block_res_sample = self.controlnet(latent_model_input, t, encoder_hidden_states=prompt_embeds,controlnet_cond=control,return_dict=False) + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + **kwargs, + ).sample + if use_pnp and t.cpu() in timesteps: + saved_features0.append(self.unet.up_blocks[1].resnets[0].out_layers_features.cpu()) + saved_features1.append(self.unet.up_blocks[1].resnets[1].out_layers_features.cpu()) + saved_features2.append(self.unet.up_blocks[2].resnets[0].out_layers_features.cpu()) + saved_q4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.q.cpu()) + saved_k4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.k.cpu()) + saved_q5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.q.cpu()) + saved_k5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.k.cpu()) + saved_q6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.q.cpu()) + saved_k6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.k.cpu()) + saved_q7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.q.cpu()) + saved_k7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.k.cpu()) + saved_q8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.q.cpu()) + saved_k8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.k.cpu()) + saved_q9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.q.cpu()) + saved_k9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.k.cpu()) + + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample.to(dtype=weight_dtype) + if i == len(inverse_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0): + progress_bar.update() + if use_pnp: + saved_features0.reverse() + saved_features1.reverse() + saved_features2.reverse() + saved_q4.reverse() + saved_k4.reverse() + saved_q5.reverse() + saved_k5.reverse() + saved_q6.reverse() + saved_k6.reverse() + saved_q7.reverse() + saved_k7.reverse() + saved_q8.reverse() + saved_k8.reverse() + saved_q9.reverse() + saved_k9.reverse() + + attn_inversion_dict = { + 'features0': saved_features0, 'features1': saved_features1, 'features2': saved_features2, + 'q4': saved_q4,'k4': saved_k4,'q5': saved_q5,'k5': saved_k5,'q6': saved_q6,'k6': saved_k6, + 'q7': saved_q7,'k7': saved_k7,'q8': saved_q8,'k8': saved_k8,'q9': saved_q9,'k9': saved_k9 + } + else: + attn_inversion_dict = None + ''' + inv_self_avg_dict={} + inv_cross_avg_dict={} + element_name = 'attn' + attn_size = 32 + for element_name in ['attn']: + inv_self_avg_dict[element_name]={} + inv_cross_avg_dict[element_name]={} + + self_attn_avg = editor.aggregate_attention(from_where=("up", "down", "mid"), + res=attn_size,is_cross=False) + + cross_attn_avg = editor.aggregate_attention(from_where=("up", "down", "mid"), + res=attn_size,is_cross=True) + + print('self_attn_avg',self_attn_avg.shape) + print('cross_attn_avg', cross_attn_avg.shape) + inv_self_avg_dict[element_name][attn_size]=self_attn_avg + inv_cross_avg_dict[element_name][attn_size]=cross_attn_avg + + os.makedirs(os.path.join(self.logdir, "attn_inv"), exist_ok=True) + os.makedirs(os.path.join(self.logdir, "sd_study"), exist_ok=True) + with open(os.path.join(self.logdir, + "attn_inv/inv_self_avg_dict.pkl"), + 'wb') as f: + pkl.dump(inv_self_avg_dict, f) + + with open(os.path.join(self.logdir, + "attn_inv/inv_cross_avg_dict.pkl"), + 'wb') as f: + pkl.dump(inv_cross_avg_dict, f) + + num_segments=3 + draw_pca(inv_self_avg_dict, resolution=32, dict_key='attn', + save_path=os.path.join(self.logdir, 'sd_study'), + special_name='inv_self') + + run_clusters(inv_self_avg_dict, resolution=32, dict_key='attn', + save_path=os.path.join(self.logdir, 'sd_study'), + special_name='inv_self',num_segments=num_segments) + + cross_attn_visualization = attention_util.show_cross_attention_plus_org_img(self.tokenizer, source_prompt, + image, editor, 32, ["up", "down", "mid"], save_path= os.path.join(self.logdir,'sd_study'),attention_maps=cross_attn_avg) + + + dict_key='attn' + special_name='inv_self' + resolution = 32 + threshold=0.1 + + tokenized_prompt = nltk.word_tokenize(source_prompt) + nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt)) if pos[:2] == 'NN'] + print(nouns) + + npy_name=f'cluster_{dict_key}_{resolution}_{special_name}.npy' + save_path=os.path.join(self.logdir, 'sd_study') + + abs_filename=os.path.join(self.logdir, "attn_inv", f"inv_cross_avg_dict.pkl") + inv_cross_avg_dict=read_pkl(abs_filename) + + video_cross_attention = inv_cross_avg_dict['attn'][32] + video_clusters=np.load(os.path.join(save_path, npy_name)) + + t = video_clusters.shape[0] + for i in range(t): + clusters = video_clusters[i] + cross_attention = video_cross_attention[i] + c2noun, c2mask = cluster2noun_(clusters, threshold, num_segments, nouns,cross_attention) + print('c2noun',c2noun) + merged_mask={} + for index in range(len(c2noun)): + # mask_ = merged_mask[class_name] + item=c2noun[index] + mask_ = c2mask[index] + mask_ = torch.from_numpy(mask_) + mask_ = F.interpolate(mask_.float().unsqueeze(0).unsqueeze(0), size=512, mode='nearest').round().bool().squeeze(0).squeeze(0) + + output_name = os.path.join(f"{save_path}", + f"frame_{i}_{item}_{index}.png") + + save_mask(mask_, output_name) + ''' + return latents, attn_inversion_dict + + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + batch_size, + num_channels_latents, + frames, + height, + width, + dtype, + device, + generator, + latents=None, + ): + print("self.vae_scale_factor",self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to( + device + ) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_video_latents(self, frames, batch_size, dtype, device, generator=None): + if not isinstance(frames, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(frames)}" + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(frames[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(frames).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + latents = rearrange(latents, "(b f) c h w ->b c f h w", b=batch_size) + + return latents + + def clean_features(self): + self.unet.up_blocks[1].resnets[0].out_layers_inject_features = None + self.unet.up_blocks[1].resnets[1].out_layers_inject_features = None + self.unet.up_blocks[2].resnets[0].out_layers_inject_features = None + self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = None + self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = None + self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = None + self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = None + self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = None + self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = None + self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = None + + def _get_attention_type(self): + sub_nets = self.unet.named_children() + for net in sub_nets: + if hasattr(net[1], 'children'): + for net in net[1].named_children(): + if hasattr(net[1], 'children'): + for net in net[1].named_children(): + if net[1].__class__.__name__ == "SpatioTemporalTransformerModel": + for net in net[1].named_children(): + if hasattr(net[1], 'children'): + for net in net[1].named_children(): + if net[1].__class__.__name__ == "SpatioTemporalTransformerBlock": + for net in net[1].named_children(): + if net[1].__class__.__name__ == "SparseCausalAttention": + attention_type = "SparseCausalAttention" + elif net[1].__class__.__name__ == "FullyFrameAttention": + attention_type = "FullyFrameAttention" + #print("attention_type:",attention_type) + return attention_type + + def _prepare_attention_layout(self,bsz,height,width,layouts,prompts,clip_length,attention_type,device): + ## current layouts f s c h w + ## org layouts s c h w + + #print("prompt:",prompts) + # sp_sz =self.unet.sample_size + sp_sz = height*width + frames, seg_cls, c, h ,w = layouts.shape + text_input = self.tokenizer(prompts, padding="max_length", return_length=True, return_overflowing_tokens=False, + max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") + cond_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] + + uncond_input = self.tokenizer([""]*bsz, padding="max_length", max_length=self.tokenizer.model_max_length, + truncation=True, return_tensors="pt") + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + for i in range(1,len(prompts)): + wlen = text_input['length'][i] - 2 + widx = text_input['input_ids'][i][1:1+wlen] + for j in range(77): + if (text_input['input_ids'][0][j:j+wlen] == widx).sum() == wlen: + break + + ########################### + ###### prep for sreg ###### + ########################### + + sreg_maps = {} + reg_sizes = {} + reg_sizes_c = {} + + device = layouts.device + + frame_index_pre = torch.arange(frames)+(-1) + frame_index_pre = frame_index_pre.clip(0, frames-1) + + for r in range(4): + layouts_s_frames = [] + if attention_type == "SparseCausalAttention": + layouts_s_sparse_attn = [] + + h = int(height/np.power(2,r)) + w= int(width/np.power(2,r)) + #layouts torch.Size([70, 2, 1, 64, 64]) + # layouts_interpolate = F.interpolate(layouts.squeeze(2), (res, res), mode='nearest').unsqueeze(2) + layouts_interpolate = F.interpolate(layouts.squeeze(2), (h, w), mode='nearest').unsqueeze(2) + layouts_interpolate = layouts_interpolate.view(frames,seg_cls,1,-1) ## frames,seg_cls,1,res^2 + + ### implementation of sparse casual attn and fully frame attn + for i in range(frames): + #layouts_f = layouts[i] + + layouts_s = layouts_interpolate[i] + + if attention_type == "SparseCausalAttention": + + ### prepare for SparseCausalAttention query, key/value + query= layouts_s + query = query.view(query.size(0),-1,1).to(device) ### segcls,res^2,1 #[cls, 4096, 1] + + ### key should be segcls,1,2xres^2 + key = torch.cat((layouts_interpolate[0],layouts_interpolate[frame_index_pre[i]]),dim=-1).to(device) + #([cls, 1, 8192]) + + layouts_s_cross_frame_attn= (query * key).sum(0).unsqueeze(0).repeat(bsz,1,1) ## 1,4096,8192 + + layouts_s_sparse_attn.append(layouts_s_cross_frame_attn) + + layouts_s = (layouts_s.view(layouts_s.size(0),1,-1)*layouts_s.view(layouts_s.size(0),-1,1)).sum(0).unsqueeze(0).repeat(bsz,1,1) + + layouts_s_frames.append(layouts_s) + + + layouts_s_frames = torch.stack(layouts_s_frames,dim=0) + if attention_type == "SparseCausalAttention": + layouts_s_sparse_attn = torch.stack(layouts_s_sparse_attn,dim=0) + sreg_maps[h*w] = layouts_s_sparse_attn + reg_sizes[h*w] = 1-1.*layouts_s_frames.sum(-1, keepdim=True)/(np.power(clip_length, 2)) + reg_sizes_c[h*w] = 1-1.*layouts_s_frames.sum(-1, keepdim=True)/(np.power(clip_length, 2)) + #### code for check error##### + # num_nonzero = torch.count_nonzero(layouts_s_frames) + # print("num_nonzero",num_nonzero) + # print("layouts_s_frames",layouts_s_frames.shape) + # print("layouts_s_frames",layouts_s_frames) + # print("reg_size final shape:", (1-1.*layouts_s_frames.sum(-1, keepdim=True)/(np.power(res, 2))).shape) + # print("reg_size", (1-1.*layouts_s_frames.sum(-1, keepdim=True)/(np.power(res, 2)))) + #### code for check error##### + + + #print("layouts_s",layouts_s.shape) + + #print("layouts_s.view(layouts_s.size(0),-1,1)",*layouts_s.view(layouts_s.size(0),-1,1).shape) + + if attention_type == "FullyFrameAttention": + layouts_s= rearrange(layouts_interpolate,"f s c res -> s c (f res)") + if r==0: + layout_s = None + reg_sizes[h*w] = None + sreg_maps[h*w] = None + reg_sizes_c[h*w] = None + else: + layouts_s = (layouts_s*layouts_s.view(layouts_s.size(0),-1,1)).sum(0).unsqueeze(0).repeat(bsz,1,1).to(torch.float16) + sreg_maps[h*w] = layouts_s + reg_sizes[h*w] = 1-1.*layouts_s.sum(-1, keepdim=True)/((h*clip_length)*(w*clip_length)) + reg_sizes_c[h*w] = 1-1.*layouts_s_frames.sum(-1, keepdim=True)/(h*w) + #print("layouts_s",layouts_s.shape) + # if res == 64: + # reg_sizes[np.power(res, 2)] = None + # else: + # reg_sizes[np.power(res, 2)] = 1-1.*layouts_s.sum(-1, keepdim=True)/(np.power(res*clip_length, 2)) + # #sreg_maps[np.power(res, 2)] = layouts_s_frames + # sreg_maps[np.power(res, 2)] = layouts_s + # reg_sizes_c[np.power(res, 2)] = 1-1.*layouts_s_frames.sum(-1, keepdim=True)/(np.power(res, 2)) + + + ########################### + ###### prep for creg ###### + ########################### + pww_maps = torch.zeros(frames, 1, 77, height, width).to(device) + for i in range(1,len(prompts)): + wlen = text_input['length'][i] - 2 + widx = text_input['input_ids'][i][1:1+wlen] + for j in range(77): + if (text_input['input_ids'][0][j:j+wlen] == widx).sum() == wlen: + for f in range(frames): + pww_maps[f,:,j:j+wlen,:,:] = layouts[f,i-1:i] # frames, seg_cls, c, h ,w = layouts.shape + cond_embeddings[0][j:j+wlen] = cond_embeddings[i][1:1+wlen] + print(prompts[i], i, '-th segment is handled.') + break + + # print("cond_embeddings",cond_embeddings) + creg_maps = {} + for r in range(4): + pww_maps_frames = [] + h = int(height/np.power(2,r)) + w = int(width/np.power(2,r)) + for i in range(frames): + pww_map_frame = pww_maps[i] + pww_map_frame.view(1,77,height,width) + pww_map_frame = F.interpolate(pww_map_frame, (h, w), mode='nearest') + pww_map_frame = pww_map_frame.view(1, 77, -1).permute(0, 2, 1).repeat(bsz,1,1) # 重新调整形状 + pww_maps_frames.append(pww_map_frame) + # 使用 torch.cat 连接处理后的所有帧 + layout_c = torch.stack(pww_maps_frames, dim=0) + # print("layout_c",layout_c) + creg_maps[h*w] = layout_c + + ########################### + #### prep for text_emb #### + ########################### + text_cond = torch.cat([uncond_embeddings, cond_embeddings[:1].repeat(bsz,1,1)]) + + return text_cond, sreg_maps, creg_maps, reg_sizes, reg_sizes_c + + + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + latent_mask: Union[torch.FloatTensor, PIL.Image.Image] = None, + layouts: Union[torch.FloatTensor, PIL.Image.Image] = None, + blending_percentage: float=0.25, + modulated_percentage: float=0.3, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = None, + num_inference_steps: int = 50, + clip_length: int = 8, + guidance_scale: float = 7.5, + source_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + control: Optional[torch.FloatTensor] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + logdir: str=None, + controlnet_conditioning_scale: float = 1.0, + use_pnp: bool = False, + attn_inversion_dict: dict=None, + **kwargs, + ): + + # 0. Default height and width to unet + t , c , height, width = image.shape + prompt = OmegaConf.to_container(prompt, resolve=True) + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, strength) + + # 2. Define call parameters + batch_size = 1 + weight_dtype = image.dtype + device = self._execution_device + + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + + + if latents is None: + latents, attn_inversion_dict = self.prepare_latents_ddim_inverted( + image, batch_size, source_prompt, + do_classifier_free_guidance, generator, + control, controlnet_conditioning_scale, use_pnp + ) + print("use inversion latents") + + ## prepare text embedding, self attention map, cross attention map + _, _, _, downsample_height, downsample_width = latents.shape + + attention_type = self._get_attention_type() + text_cond, sreg_maps, creg_maps, reg_sizes,reg_sizes_c = self._prepare_attention_layout(batch_size,downsample_height,downsample_width, + layouts,prompt,clip_length,attention_type,device) + + time_steps = self.scheduler.timesteps + + #============do visualization for st-layout attn===============# + self.store_controller = attention_util.AttentionStore() + editor = ModulatedAttention_ControlEdit(text_cond=text_cond,sreg_maps=sreg_maps,creg_maps=creg_maps,reg_sizes=reg_sizes,reg_sizes_c=reg_sizes_c, + time_steps=time_steps,clip_length=clip_length,attention_type=attention_type, + additional_attention_store=self.store_controller, + save_self_attention = True, + disk_store = False, + video = image, + ) + attention_util.register_attention_control(self, editor, text_cond, clip_length, downsample_height,downsample_width,ddim_inversion=False) + #============do visualization for st-layout attn===============# + + # editor = ModulatedAttentionControl(text_cond=text_cond,sreg_maps=sreg_maps,creg_maps=creg_maps,reg_sizes=reg_sizes,reg_sizes_c=reg_sizes_c, + # time_steps=time_steps,clip_length=clip_length,attention_type=attention_type) + + # register_attention_control(self, editor, text_cond, clip_length,downsample_height,downsample_width,ddim_inversion=False) + + # 3. Encode input prompt + prompt = prompt[:1] + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + source_latents = self.prepare_source_latents( + image, batch_size, num_images_per_prompt, + # text_embeddings.dtype, device, + text_embeddings, + generator, + ) + + # 7. Denoising loop + num_warmup_steps = len(time_steps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps* (1-blending_percentage)) as progress_bar: + for i, t in enumerate(time_steps[int(len(time_steps) * blending_percentage):]): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # inject features + if use_pnp and i < kwargs["inject_step"]: + self.unet.up_blocks[1].resnets[0].out_layers_inject_features = attn_inversion_dict['features0'][i].to(device) + self.unet.up_blocks[1].resnets[1].out_layers_inject_features = attn_inversion_dict['features1'][i].to(device) + self.unet.up_blocks[2].resnets[0].out_layers_inject_features = attn_inversion_dict['features2'][i].to(device) + self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q4'][i].to(device) + self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k4'][i].to(device) + self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q5'][i].to(device) + self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k5'][i].to(device) + self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q6'][i].to(device) + self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k6'][i].to(device) + self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q7'][i].to(device) + self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k7'][i].to(device) + self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q8'][i].to(device) + self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k8'][i].to(device) + self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = attn_inversion_dict['q9'][i].to(device) + self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = attn_inversion_dict['k9'][i].to(device) + else: + self.clean_features() + + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=control, + return_dict=False, + ) + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + **kwargs, + ).sample.to(dtype=weight_dtype) + + + # perform guidance + if do_classifier_free_guidance: + # print("do_classifier_free_guidance") + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + + # Blending + noise_source_latents = self.scheduler.add_noise( + source_latents, torch.randn_like(latents), t + ) + + latents = latents * latent_mask + noise_source_latents * (1 - latent_mask) + + # call the callback, if provided + if i == len(time_steps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + ### vis cross attn + # image shape fchw + # save_path = os.path.join(logdir,'visualization_denoise') + # os.makedirs(save_path, exist_ok=True) + # attention_output = attention_util.show_cross_attention_plus_org_img(self.tokenizer,prompt, image, editor, 32, ["up","down"],save_path=save_path) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + has_nsfw_concept = None + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + torch.cuda.empty_cache() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/video_diffusion/pipelines/stable_diffusion.py b/video_diffusion/pipelines/stable_diffusion.py new file mode 100755 index 0000000000000000000000000000000000000000..279754418c5d3da418074f0c5210c74e22ef466b --- /dev/null +++ b/video_diffusion/pipelines/stable_diffusion.py @@ -0,0 +1,619 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import inspect +from typing import Callable, List, Optional, Union +import os, sys + +import torch +from einops import rearrange + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DDIMInverseScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ..models.unet_3d_condition import UNetPseudo3DConditionModel +from ..models.controlnet3d import ControlNetModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SpatioTemporalStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. + """ + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNetPseudo3DConditionModel, #UNet3DConditionModel, + controlnet: ControlNetModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + inverse_scheduler: DDIMInverseScheduler + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + inverse_scheduler=inverse_scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + def prepare_before_train_loop(self, params_to_optimize=None): + # Set xformers in train.py + + # self.disable_xformers_memory_efficient_attention() + + self.vae.requires_grad_(False) + self.unet.requires_grad_(False) + self.controlnet.requires_grad_(False) + self.text_encoder.requires_grad_(False) + + self.vae.eval() + self.unet.eval() + self.controlnet.requires_grad_(False) + self.text_encoder.eval() + + if params_to_optimize is not None: + params_to_optimize.requires_grad = True + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + print('prompt',prompt) + print('negative_prompt',negative_prompt) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + is_video = (latents.dim() == 5) + b = latents.shape[0] + latents = 1 / 0.18215 * latents + + if is_video: + latents = rearrange(latents, "b c f h w -> (b f) c h w") # torch.Size([70, 4, 64, 64]) + + latents_split = torch.split(latents, 16, dim=0) + image = torch.cat([self.vae.decode(l).sample for l in latents_split], dim=0) + + # image_full = self.vae.decode(latents).sample + # RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements + # Pytorch upsample alogrithm not work for batch size 32 -> 64 + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + + image = image.cpu().float().numpy() + if is_video: + image = rearrange(image, "(b f) c h w -> b f h w c", b=b) + else: + image = rearrange(image, "b c h w -> b h w c", b=b) + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + clip_length, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + clip_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to( + device + ) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + clip_length: int = 8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + # [1, 4, 8, 64, 64] + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + clip_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents=None, + ) + latents_dtype = latents.dtype + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + # [2, 4, 8, 64, 64] + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 [1, 4, 8, 64, 64] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + has_nsfw_concept = None + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + torch.cuda.empty_cache() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @staticmethod + def numpy_to_pil(images): + # (1, 16, 512, 512, 3) + pil_images = [] + is_video = (len(images.shape)==5) + print('is_video',is_video) + if is_video: + for sequence in images: + pil_images.append(DiffusionPipeline.numpy_to_pil(sequence)) + else: + pil_images.append(DiffusionPipeline.numpy_to_pil(images)) + return pil_images + + def print_pipeline(self, logger): + print('Overview function of pipeline: ') + print(self.__class__) + + print(self) + + expected_modules, optional_parameters = self._get_signature_keys(self) + components_details = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + import json + logger.info(str(components_details)) + # logger.info(str(json.dumps(components_details, indent = 4))) + # print(str(components_details)) + # print(self._optional_components) + + print(f"python version {sys.version}") + print(f"torch version {torch.__version__}") + print(f"validate gpu status:") + print( torch.tensor(1.0).cuda()*2) + os.system("nvcc --version") + + import diffusers + print(diffusers.__version__) + print(diffusers.__file__) + + try: + import bitsandbytes + print(bitsandbytes.__file__) + except: + print("fail to import bitsandbytes") + # os.system("accelerate env") + # os.system("python -m xformers.info") diff --git a/video_diffusion/pipelines/validation_loop.py b/video_diffusion/pipelines/validation_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd0e3c6a0fb86c9adf0942fb4eb3c8eeecdb22b --- /dev/null +++ b/video_diffusion/pipelines/validation_loop.py @@ -0,0 +1,210 @@ +import os +import numpy as np +from typing import List, Union +import PIL + + +import torch +import torch.utils.data +import torch.utils.checkpoint + +from diffusers.pipeline_utils import DiffusionPipeline +from tqdm.auto import tqdm +from video_diffusion.common.image_util import make_grid, annotate_image +from video_diffusion.common.image_util import save_gif_mp4_folder_type +import cv2 + +class SampleLogger: + def __init__( + self, + editing_prompts: List[str], + clip_length: int, + logdir: str, + subdir: str = "sample", + num_samples_per_prompt: int = 1, + sample_seeds: List[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 7, + strength: float = None, + annotate: bool = False, + annotate_size: int = 15, + make_grid: bool = True, + grid_column_size: int = 2, + layout_mask_dir: str = None, # New parameter for the layout mask directory + layouts_masks_orders: List[str]=None, + stride: int = 1, + n_sample_frame: int = 8, + start_sample_frame: int = None, + sampling_rate: int = 1, + **args + + ) -> None: + self.editing_prompts = editing_prompts + self.clip_length = clip_length + self.guidance_scale = guidance_scale + self.num_inference_steps = num_inference_steps + self.strength = strength + + if sample_seeds is None: + max_num_samples_per_prompt = int(1e5) + if num_samples_per_prompt > max_num_samples_per_prompt: + raise ValueError + sample_seeds = torch.randint(0, max_num_samples_per_prompt, (num_samples_per_prompt,)) + sample_seeds = sorted(sample_seeds.numpy().tolist()) + self.sample_seeds = sample_seeds + + self.logdir = os.path.join(logdir, subdir) + os.makedirs(self.logdir, exist_ok=True) + + self.annotate = annotate + self.annotate_size = annotate_size + self.make_grid = make_grid + self.grid_column_size = grid_column_size + + + self.layout_mask_dir = layout_mask_dir # Initialize layout_mask_dir + self.layout_mask_orders = layouts_masks_orders + self.stride = stride + self.n_sample_frame = n_sample_frame + self.start_sample_frame = start_sample_frame + self.sampling_rate = sampling_rate + + + def _read_mask(self, mask_path, index: int, dest_size=(64, 64)): + mask_path = os.path.join(mask_path, f"{index:05d}.png") + mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) + mask = (mask > 0).astype(np.uint8) + mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST) + mask = mask[np.newaxis, ...] + return mask + + def get_frame_indices(self, index): + if self.start_sample_frame is not None: + frame_start = self.start_sample_frame + self.stride * index + else: + frame_start = self.stride * index + return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame)) + + def read_layout_and_merge_masks(self, index): + + layouts_all, masks_all = [],[] + for idx,layout_mask_order_per in enumerate(self.layout_mask_orders): + layout_ = [] + for layout_name in layout_mask_order_per: # Loop over prompts + frame_indices = self.get_frame_indices(index % self.clip_length) + layout_mask_dir = os.path.join(self.layout_mask_dir, layout_name) + mask = [self._read_mask(layout_mask_dir, i) for i in frame_indices] + masks = np.stack(mask) + layout_.append(masks) + layout_ = np.stack(layout_) + + merged_masks = [] + for i in range(int(self.n_sample_frame)): + merged_mask_frame = np.sum(layout_[:, i, :, :], axis=0) + merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8) + merged_masks.append(merged_mask_frame) + + masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w") + masks = torch.from_numpy(masks).half() + + layouts = rearrange(layout_, "s f c h w -> f s c h w") + layouts = torch.from_numpy(layouts).half() + layouts_all.append(layouts) + masks_all.append(mask) + return masks_all, layouts_all + + def log_sample_images( + self, pipeline: DiffusionPipeline, + device: torch.device, step: int, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + masks: Union[torch.FloatTensor, PIL.Image.Image] = None, + layouts : Union[torch.FloatTensor, PIL.Image.Image] = None, + latents: torch.FloatTensor = None, + control: torch.FloatTensor = None, + controlnet_conditioning_scale = None, + negative_prompt: Union[str, List[str]] = None, + blending_percentage = None, + trajs = None, + flatten_res = None, + source_prompt = None, + inject_step = None, + old_qk = None, + use_pnp = None, + attn_inversion_dict = None, + ): + torch.cuda.empty_cache() + samples_all = [] + attention_all = [] + # handle input image + if image is not None: + input_pil_images = pipeline.numpy_to_pil(tensor_to_numpy(image))[0] + samples_all.append(input_pil_images) + # samples_all.append([ + # annotate_image(image, "input sequence", font_size=self.annotate_size) for image in input_pil_images + # ]) + #masks_all, layouts_all = self.read_layout_and_merge_masks() + #for idx, (prompt, masks, layouts) in enumerate(tqdm(zip(self.editing_prompts, masks_all, layouts_all), desc="Generating sample images")): + for idx, prompt in enumerate(tqdm(self.editing_prompts, desc="Generating sample images")): + for seed in self.sample_seeds: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + sequence_return = pipeline( + prompt=prompt, + image=image, # torch.Size([8, 3, 512, 512]) + latent_mask=masks, + layouts = layouts, + strength=self.strength, + generator=generator, + num_inference_steps=self.num_inference_steps, + clip_length=self.clip_length, + guidance_scale=self.guidance_scale, + num_images_per_prompt=1, + # used in null inversion + control = control, + controlnet_conditioning_scale = controlnet_conditioning_scale, + latents = latents, + #uncond_embeddings_list = uncond_embeddings_list, + blending_percentage = blending_percentage, + logdir = self.logdir, + trajs = trajs, + flatten_res = flatten_res, + negative_prompt=negative_prompt, + source_prompt=source_prompt, + inject_step=inject_step, + old_qk=old_qk, + use_pnp=use_pnp, + attn_inversion_dict=attn_inversion_dict, + # Put the source prompt at the first one, when using p2p + ) + + sequence = sequence_return.images[0] + torch.cuda.empty_cache() + + if self.annotate: + images = [ + annotate_image(image, prompt, font_size=self.annotate_size) for image in sequence + ] + else: + images = sequence + + if self.make_grid: + samples_all.append(images) + save_path = os.path.join(self.logdir, f"step_{step}_{idx}_{seed}.gif") + save_gif_mp4_folder_type(images, save_path) + + if self.make_grid: + samples_all = [make_grid(images, cols=int(np.ceil(np.sqrt(len(samples_all))))) for images in zip(*samples_all)] + save_path = os.path.join(self.logdir, f"step_{step}.gif") + save_gif_mp4_folder_type(samples_all, save_path) + return samples_all + + +from einops import rearrange + +def tensor_to_numpy(image, b=1): + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + + image = image.cpu().float().numpy() + image = rearrange(image, "(b f) c h w -> b f h w c", b=b) + return image \ No newline at end of file diff --git a/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35a9faa2ce1929b11f6a5a5e2fe7c0dbec55f471 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41b6c51c425461d70d7c89ebe8cc123138d0917b Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a647b62458ca393de02ecf84547196d1918fa9c7 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_register.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a16f46138bdbf4a9657668e8501c44851512da Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e15145e214022b76e01f5b32808111d54300ac2f Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48c91d56bec35671aedf6ac02d52da536c70553 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_store.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b51e86bfd76636691f774af4243150b8053e7ff2 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c11798eae6fe00fa32f87b38078056e3c8dc90d2 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dd3a09420f988abd93fb5b821dde264b87bd81a Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/attention_util.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/free_lunch_utils.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/free_lunch_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d59dea9d9f17c7ea062d083482dc81b5e94003 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/free_lunch_utils.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c599a8d4c2469fccc5e1f151fe9f49d4362b6de Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..333ecadbeb733a5326a24a9731a1eb1be10ce02c Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c20a7e8ade81da12e4866b6ebc477fbe63dcaf98 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/ptp_utils.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/sd_study_utils.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/sd_study_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93cd5a7376204949f1656a3897ec108239eeee87 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/sd_study_utils.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df90593a708694cddd245aea7d6a8ae19153a89e Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b2eb2e3959bd3b71fabde370522d3c760f761d7 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979236698ab944563c3b7a6bfce8a01a02290491 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/seq_aligner.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..081edd26c8ec62f5f8f5716212c02adf813637ce Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7409b81932cca57dd450f1616ebea06368624a7 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..740d3cb3cc8fbb1ee49eb8ddfc039a7cafc1b607 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/spatial_blend.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/visualization.cpython-310.pyc b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9daf70513afa81ef0f353499098c4e5057d9e59 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-310.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/visualization.cpython-38.pyc b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67346dd2ff3c26d46c5642ee18a0bcfa91f85130 Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-38.pyc differ diff --git a/video_diffusion/prompt_attention/__pycache__/visualization.cpython-39.pyc b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9bbdedb2ef2428ac0a8f9d2012183bb83ed552f Binary files /dev/null and b/video_diffusion/prompt_attention/__pycache__/visualization.cpython-39.pyc differ diff --git a/video_diffusion/prompt_attention/attention_register.py b/video_diffusion/prompt_attention/attention_register.py new file mode 100644 index 0000000000000000000000000000000000000000..02e83b4644d2b2f06c32a62e5bd626b8447d14fe --- /dev/null +++ b/video_diffusion/prompt_attention/attention_register.py @@ -0,0 +1,556 @@ +""" +register the attention controller into the UNet of stable diffusion +Build a customized attention function `_attention' +Replace the original attention function with `forward' and `spatial_temporal_forward' in attention_controlled_forward function +Most of spatial_temporal_forward is directly copy from `video_diffusion/models/attention.py' +TODO FIXME: merge redundant code with attention.py +""" + +from einops import rearrange +import torch +import torch.nn.functional as F +import math +from diffusers.utils.import_utils import is_xformers_available +import numpy as np + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def register_attention_control(model, controller, text_cond, clip_length, height, width, ddim_inversion): + "Connect a model with a controller" + def attention_controlled_forward(self, place_in_unet, attention_type='cross'): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + def _attention(query, key, value, is_cross, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + # print("query",query.shape) + # print("key",key.shape) + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + #print("attention_scores",attention_scores.shape) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + # START OF CORE FUNCTION + # if not ddim_inversion: + attention_probs = controller(reshape_batch_dim_to_temporal_heads(attention_scores), + is_cross, place_in_unet) + attention_probs = reshape_temporal_heads_to_batch_dim(attention_probs) + # END OF CORE FUNCTION + + attention_probs = attention_probs.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def reshape_temporal_heads_to_batch_dim(tensor): + head_size = self.heads + tensor = rearrange(tensor, " b h s t -> (b h) s t ", h = head_size) + return tensor + + def reshape_batch_dim_to_temporal_heads(tensor): + head_size = self.heads + tensor = rearrange(tensor, "(b h) s t -> b h s t", h = head_size) + return tensor + + def reshape_heads_to_batch_dim3(tensor): + batch_size1, batch_size2, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 3, 1, 2, 4) + return tensor + + def reshape_heads_to_batch_dim(tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + + def reshape_batch_dim_to_heads(tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def _memory_efficient_attention_xformers(query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): + # hidden_states: torch.Size([16, 4096, 320]) + # encoder_hidden_states: torch.Size([16, 77, 768]) + is_cross = encoder_hidden_states is not None + + #encoder_hidden_states = encoder_hidden_states + + text_cond_frames = text_cond.repeat_interleave(clip_length, 0) # wrong implementation text_cond.repeat(clip_length,1,1) + + ######for debug###### + # text_cond_repeat_interleave = text_cond.repeat_interleave(clip_length, 0) + # print("after repeat interleave", text_cond_repeat_interleave.shape, text_cond_repeat_interleave.view(-1)[:20]) + # text_cond_repeat = text_cond.repeat(clip_length,1,1) + # print("First 20 elements after repeat:", text_cond_repeat.shape, text_cond_repeat.view(-1)[:20]) + ######for debug###### + + encoder_hidden_states = text_cond_frames + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + query = reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = reshape_heads_to_batch_dim(key) + value = reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = text_cond_frames if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = reshape_heads_to_batch_dim(key) + value = reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height//2) * (width//2)): + # for large attention map of 64X64, use xformers to save memory + hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + + hidden_states = _attention(query, key, value, is_cross=is_cross, attention_mask=attention_mask) + # else: + # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + #dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def spatial_temporal_forward( + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + SparseCausalAttention_index: list = [-1, 'first'] #list = [0] + ): + """ + Most of spatial_temporal_forward is directly copy from `video_diffusion.models.attention.SparseCausalAttention' + We add two modification + 1. use self defined attention function that is controlled by AttentionControlEdit module + 2. remove the dropout to reduce randomness + FIXME: merge redundant code with attention.py + + """ + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + + query = reshape_heads_to_batch_dim(query) + + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + if clip_length is not None: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + + # *********************** Start of Spatial-temporal attention ********** + frame_index_list = [] + + if len(SparseCausalAttention_index) > 0: + for index in SparseCausalAttention_index: + if isinstance(index, str): + if index == 'first': + frame_index = [0] * clip_length + if index == 'last': + frame_index = [clip_length-1] * clip_length + if (index == 'mid') or (index == 'middle'): + frame_index = [int((clip_length-1)//2)] * clip_length + else: + assert isinstance(index, int), 'relative index must be int' + frame_index = torch.arange(clip_length) + index + frame_index = frame_index.clip(0, clip_length-1) + + frame_index_list.append(frame_index) + # print("frame_index_list",frame_index_list) [bz, frame, 4096, 320] + + key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list #[bz, frame, 8192, 320]) + ], dim=2) + value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list + ], dim=2) + + + # *********************** End of Spatial-temporal attention ********** + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + # print("key after rearrange",key.shape) + # print("value after rearrange",value.shape) + + key = reshape_heads_to_batch_dim(key) + value = reshape_heads_to_batch_dim(value) + + # print("query after head to batch dim",query.shape) + # print("key after head to batch dim",key.shape) + + if torch.isnan(query.reshape(-1)[0]): + print("nan value query",query.reshape(-1)[:10]) + print("nan value key",key.reshape(-1)[:10]) + exit() + + # print("query after reshape heads to batch ",query.shape) + # print("key after reshape heads to batch",key.shape) + + if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height//2) * (width//2)): + # FIXME there should be only one variable to control whether use xformers + # if self._use_memory_efficient_attention_xformers: + # for large attention map of 64X64, use xformers to save memory + hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + # if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = _attention(query, key, value, attention_mask=attention_mask, is_cross=False) + # else: + # hidden_states = self._sliced_attention( + # query, key, value, hidden_states.shape[1], dim, attention_mask + # ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _sliced_attention(query, key, value, sequence_length, dim, attention_mask): + #query (bz*heads, t x h x w, org_dim//heads ) + + is_cross = False + batch_size_attention = query.shape[0] # bz * heads + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + + if ddim_inversion: + per_frame_len = sequence_length//clip_length + attention_store = torch.zeros((batch_size_attention, clip_length, per_frame_len, per_frame_len), device=query.device, dtype=query.dtype) + + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + if i < self.heads: + if not ddim_inversion: + attention_probs = controller((attn_slice.unsqueeze(1)),is_cross, place_in_unet) + attn_slice = attention_probs.squeeze(1) + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + ## bz == 1, sliced head + if ddim_inversion: + # attn_slice (1, thw, thw) + bz, thw, thw = attn_slice.shape + t = clip_length + hw = thw // t + # 初始化 per_frame_attention + # (1, t, hxw) + + per_frame_attention = torch.empty((bz, t, hw, hw), device=attn_slice.device) + + # # 循环提取每一帧的对角线注意力 + for idx in range(t): + start_idx_ = idx * hw + end_idx_ = (idx + 1) * hw + # per frame attention extraction + per_frame_attention[:, idx, :, :] = attn_slice[:, start_idx_:end_idx_, start_idx_:end_idx_] + + # current_query_block = attn_slice[:, start_idx_:end_idx_, :] + # aggregated_attention = current_query_block.view(bz, hw, t, hw).mean(dim=2) + # # print('aggregated_attention',aggregated_attention.shape) + # per_frame_attention[:, idx, :, :] = aggregated_attention + + per_frame_attention = rearrange(per_frame_attention, "b t h w -> (b t) h w") + attention_store[start_idx:end_idx] = per_frame_attention + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + if ddim_inversion: + # attention store (bz*heads, t , h, w) h=res, w=res + _ = controller(attention_store, is_cross, place_in_unet) + + # reshape hidden_states + hidden_states = reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + + def fully_frame_forward(hidden_states, encoder_hidden_states=None, attention_mask=None, clip_length=None, inter_frame=False, **kwargs): + batch_size, sequence_length, _ = hidden_states.shape + # print("hidden_states.shape",hidden_states.shape) + # print("sequence_length",sequence_length) + + encoder_hidden_states = encoder_hidden_states + h = kwargs['height'] + w = kwargs['width'] + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # (bf) x d(hw) x c + self.q = query + if self.inject_q is not None: + query = self.inject_q + dim = query.shape[-1] + query_old = query.clone() + + # All frames + #init query (bz*t, hxw, dim) + query = rearrange(query, "(b f) d c -> b (f d) c", f=clip_length) + query = reshape_heads_to_batch_dim(query) #(bz*heads, txhxw, dim//heads) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + self.k = key + if self.inject_k is not None: + key = self.inject_k + key_old = key.clone() + value = self.to_v(encoder_hidden_states) + + if inter_frame: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]] + key = rearrange(key, "b f d c -> b (f d) c",) + value = rearrange(value, "b f d c -> b (f d) c") + else: + # All frames + key = rearrange(key, "(b f) d c -> b (f d) c", f=clip_length) + value = rearrange(value, "(b f) d c -> b (f d) c", f=clip_length) + + key = reshape_heads_to_batch_dim(key) + value = reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + #print("query.shape[0]",query.shape[0]) # 16 + self._slice_size = 1 ### 8 + sequence_length_full_frame = query.shape[1] + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers and query.shape[-2] > clip_length*(32 ** 2): + hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + + else: + # if ddim_inversion: + # #if self._slice_size is None or query.shape[0] // self._slice_size == 1: + # hidden_states = _attention(query, key, value, attention_mask) + # else: + hidden_states = _sliced_attention(query, key, value, sequence_length_full_frame, dim, attention_mask) + + if [h,w] in kwargs['flatten_res']: + hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length) + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if kwargs["old_qk"] == 1: + query = query_old + key = key_old + else: + query = hidden_states + key = hidden_states + value = hidden_states + + traj = kwargs["traj"] + traj = rearrange(traj, '(f n) l d -> f n l d', f=clip_length, n=sequence_length) + mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=clip_length, n=sequence_length) + mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -clip_length+1:]], dim=-1) + + #print('traj',traj.shape) + #print('mask',mask.shape) + + traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -clip_length+1:, :]], dim=-2) + t_inds = traj_key_sequence_inds[:, :, :, 0] + x_inds = traj_key_sequence_inds[:, :, :, 1] + y_inds = traj_key_sequence_inds[:, :, :, 2] + + query_tempo = query.unsqueeze(-2) + _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + key_tempo = _key[:, t_inds, x_inds, y_inds] + value_tempo = _value[:, t_inds, x_inds, y_inds] + key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') + value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') + + mask = rearrange(torch.stack([mask, mask]), 'b f n l -> (b f) n l') + mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2) + attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype) # regular zeros_like + attn_bias[~mask] = -torch.inf + + # print('attn_bias',attn_bias.shape) + # print('query_tempo',query_tempo.shape) + # print('key_tempo',key_tempo.shape) + + # flow attention + query_tempo = reshape_heads_to_batch_dim3(query_tempo) + key_tempo = reshape_heads_to_batch_dim3(key_tempo) + value_tempo = reshape_heads_to_batch_dim3(value_tempo) + + attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias + attn_matrix2 = F.softmax(attn_matrix2, dim=-1) + out = (attn_matrix2@value_tempo).squeeze(-2) + + hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/clip_length), f=clip_length, h=h, w=w) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + # All frames + hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length) + return hidden_states + + + if attention_type == 'CrossAttention': + # return mod_forward + return forward + elif attention_type == "SparseCausalAttention": + #return mod_forward + return spatial_temporal_forward + elif attention_type == "FullyFrameAttention": + #return mod_forward + return fully_frame_forward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_[1].__class__.__name__ == 'CrossAttention' \ + or net_[1].__class__.__name__ == 'FullyFrameAttention' \ + or net_[1].__class__.__name__ == 'SparseCausalAttention' : + net_[1].forward = attention_controlled_forward(net_[1], place_in_unet, attention_type = net_[1].__class__.__name__) + return count + 1 + elif hasattr(net_[1], 'children'): + for net in net_[1].named_children(): + if net[0] !='attn_temporal': + + count = register_recr(net, count, place_in_unet) + + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net, 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net, 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net, 0, "mid") + #print(f"Number of attention layer registered {cross_att_count}") + controller.num_att_layers = cross_att_count diff --git a/video_diffusion/prompt_attention/attention_store.py b/video_diffusion/prompt_attention/attention_store.py new file mode 100644 index 0000000000000000000000000000000000000000..edb0611830f732105b9a1f757c5a5295e5b54b5a --- /dev/null +++ b/video_diffusion/prompt_attention/attention_store.py @@ -0,0 +1,132 @@ +""" +Code of attention storer AttentionStore, which is a base class for attention editor in attention_util.py + +""" + +import abc +import os +import copy +import torch +from video_diffusion.common.util import get_time_string +from einops import rearrange +from typing import Any, Callable, Dict, List, Optional, Union + +class AttentionControl(abc.ABC): + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + """I guess the diffusion of google has some unconditional attention layer + No unconditional attention layer in Stable diffusion + + Returns: + _type_: _description_ + """ + # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0 + return 0 + + @abc.abstractmethod + def forward (self, attn, is_cross: bool, place_in_unet: str): + return attn + # raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + # For classifier-free guidance scale!=1 + #print("half forward") + h = attn.shape[0] + if h == 1: + #print("sliced attn") + attn = self.forward(attn, is_cross, place_in_unet) + self.sliced_attn_head_count+=1 + if self.sliced_attn_head_count == 8: + self.cur_att_layer += 1 + self.sliced_attn_head_count = 0 + else: + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers-10: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self, + ): + self.LOW_RESOURCE = False # assume the edit have cfg + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.sliced_attn_head_count = 0 + + + +class AttentionStore(AttentionControl): + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + @staticmethod + def get_empty_cross_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": [] + } + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[2] <= 32 ** 2: + # if not is_cross: + append_tensor = attn.cpu().detach() + self.step_store[key].append(copy.deepcopy(append_tensor)) + + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + + self.step_store = self.get_empty_store() + + def get_average_attention(self): + "divide the attention map value in attention store by denoising steps" + average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} + return average_attention + + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_cross_store() + self.attention_store_all_step = [] + self.attention_store = {} + + def __init__(self, save_self_attention:bool=True, disk_store=False): + super(AttentionStore, self).__init__() + self.disk_store = disk_store + if self.disk_store: + time_string = get_time_string() + path = f'./trash/attention_cache_{time_string}' + os.makedirs(path, exist_ok=True) + self.store_dir = path + else: + self.store_dir =None + self.step_store = self.get_empty_store() + self.attention_store = {} + self.save_self_attention = save_self_attention + self.latents_store = [] + self.attention_store_all_step = [] diff --git a/video_diffusion/prompt_attention/attention_util.py b/video_diffusion/prompt_attention/attention_util.py new file mode 100644 index 0000000000000000000000000000000000000000..38ea7ccbe94b26bb3e9d84d202dd3a455ded101d --- /dev/null +++ b/video_diffusion/prompt_attention/attention_util.py @@ -0,0 +1,423 @@ +""" +Collect all function in prompt_attention folder. +Provide a API `make_controller' to return an initialized AttentionControlEdit class object in the main validation loop. +""" + +from typing import Optional, Union, Tuple, List, Dict +import abc +import numpy as np +import copy +from einops import rearrange + +import torch +import torch.nn.functional as F + +import video_diffusion.prompt_attention.ptp_utils as ptp_utils +from video_diffusion.prompt_attention.visualization import show_cross_attention,show_cross_attention_plus_org_img,show_self_attention_comp +from video_diffusion.prompt_attention.attention_store import AttentionStore, AttentionControl +from video_diffusion.prompt_attention.attention_register import register_attention_control +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +from PIL import Image +import os +from video_diffusion.common.image_util import save_gif_mp4_folder_type,make_grid +import cv2 +import math + +from PIL import Image, ImageDraw +import numpy as np +import math +import os + +class EmptyControl: + + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +def apply_jet_colormap(weight): + # 将权重规范化到0-255 + weight = 255*(weight - weight.min()) / (weight.max() - weight.min()+1e-6) + weight = weight.astype(np.uint8) + + # 应用Jet颜色映射 + color_mapped_weight = cv2.applyColorMap(weight, cv2.COLORMAP_JET) + return color_mapped_weight + +def show_self_attention_comp(self_attention_map, video, h_index:int, w_index:int, res: int, frames:int, place_in_unet: List[str], step:int ): + + attention_maps = self_attention_map.reshape(frames, res, res, frames, res, res) + weights = attention_maps[0,h_index,w_index,:,:,:] + attention_list = [] + video_frames = [] + #video f,c,h,w + + for i in range(frames): + weight = weights[i].cpu().numpy() + weight_colored = apply_jet_colormap(weight) + weight_colored = weight_colored[:, :, ::-1] # BGR到RGB的转换 + weight_colored = np.array(Image.fromarray(weight_colored).resize((256, 256))) + attention_list.append(weight_colored) + + frame = video[i].permute(1,2,0).cpu().numpy() + mean = np.array((0.48145466, 0.4578275, 0.40821073)).reshape((1, 1, 3)) # [h, w, c] + varas = np.array((0.26862954, 0.26130258, 0.27577711)).reshape((1, 1, 3)) + frame = frame * varas + mean + frame = (frame - frame.min()) / (frame.max() - frame.min() + 1e-6) * 255 + frame = frame.astype(np.uint8) + video_frames.append(frame) + + alpha = 0.5 + overlay_frames = [] + + for frame, attention in zip(video_frames, attention_list): + + attention_resized = cv2.resize(attention, (frame.shape[1], frame.shape[0])) + + overlay_frame = cv2.addWeighted(frame, alpha, attention_resized, 1 - alpha, 0) + + overlay_frames.append(overlay_frame) + print('vis self attn') + save_path = "with_st_layout_vis_self_attn/vis_self_attn" + os.makedirs(save_path, exist_ok=True) + video_save_path = f'{save_path}/self-attn-{place_in_unet}-{step}-query-frame0-h{h_index}-w{w_index}.gif' + save_gif_mp4_folder_type(overlay_frames, video_save_path,save_gif=False) + + +def draw_grid_on_image(image, grid_size, line_color="gray"): + draw = ImageDraw.Draw(image) + w, h = image.size + for i in range(0, w, grid_size): + draw.line([(i, 0), (i, h)], fill=line_color) + for i in range(0, h, grid_size): + draw.line([(0, i), (w, i)], fill=line_color) + return image + + +def identify_self_attention_max_min(sim, video, h_index:int, w_index:int, res: int, frames:int, place_in_unet: str, step:int): + attention_maps = sim.reshape(frames, res, res, frames, res, res) + weights = attention_maps[0, h_index, w_index, :, :, :] + + flattened_weights = weights.reshape(-1) + global_max_index = flattened_weights.argmax().cpu().numpy() + global_min_index = flattened_weights.argmin().cpu().numpy() + print('weights.shape',weights.shape) + + frame_max, h_max, w_max = np.unravel_index(global_max_index, weights.shape) + frame_min, h_min, w_min = np.unravel_index(global_min_index, weights.shape) + + video_frames = [] + + query_frame_index = 0 + query_h = h_index + query_w = w_index + + for i in range(frames): + frame = video[i].permute(1, 2, 0).cpu().numpy() + mean = np.array((0.48145466, 0.4578275, 0.40821073)).reshape((1, 1, 3)) + varas = np.array((0.26862954, 0.26130258, 0.27577711)).reshape((1, 1, 3)) + frame = (frame * varas + mean) * 255 + frame = np.clip(frame, 0, 255).astype(np.uint8) + frame_img = Image.fromarray(frame) + + grid_size = 512 // res + frame_img = draw_grid_on_image(frame_img, grid_size) + + draw = ImageDraw.Draw(frame_img) + if i == frame_max: + max_pixel_pos = (w_max * grid_size, h_max * grid_size) + draw.rectangle([max_pixel_pos, (max_pixel_pos[0] + grid_size, max_pixel_pos[1] + grid_size)], outline="red", width=2) + if i == frame_min: + min_pixel_pos = (w_min * grid_size, h_min * grid_size) + draw.rectangle([min_pixel_pos, (min_pixel_pos[0] + grid_size, min_pixel_pos[1] + grid_size)], outline="blue", width=2) + + if i == query_frame_index: + query_pixel_pos = (query_w * grid_size, query_h * grid_size) + draw.rectangle([query_pixel_pos, (query_pixel_pos[0] + grid_size, query_pixel_pos[1] + grid_size)], outline="yellow", width=2) + + video_frames.append(frame_img) + + save_path = "/visualization/correspondence_with_query" + os.makedirs(save_path, exist_ok=True) + video_save_path = os.path.join(save_path, f'self-attn-{place_in_unet}-{step}-query-frame0-h{h_index}-w{w_index}.gif') + + save_gif_mp4_folder_type(video_frames, video_save_path, save_gif=False) + + + + +class ModulatedAttentionControl(AttentionControl, abc.ABC): + + def __init__(self, end_step=15, total_steps=50, step_idx=None, text_cond=None, sreg_maps=None, creg_maps=None, reg_sizes=None,reg_sizes_c=None, time_steps=None,clip_length=None,attention_type=None): + """ + Mutual self-attention control for Stable-Diffusion model + Args: + start_step: the step to start mutual self-attention control + start_layer: the layer to start mutual self-attention control + layer_idx: list of the layers to apply mutual self-attention control + step_idx: list the steps to apply mutual self-attention control + total_steps: the total number of steps + model_type: the model type, SD or SDXL + """ + super().__init__() + self.total_steps = total_steps + self.step_idx = list(range(0, end_step)) + self.total_infer_steps = 50 + self.text_cond = text_cond + self.sreg_maps = sreg_maps + self.creg_maps = creg_maps + self.reg_sizes = reg_sizes + self.reg_sizes_c = reg_sizes_c + self.clip_length = clip_length + self.attention_type = attention_type + self.sreg = .3 + self.creg = 1. + self.count = 0 + self.reg_part = .3 + self.time_steps = time_steps + print("Modulated Ctrl at denoising steps: ", self.step_idx) + + def forward(self, sim, is_cross, place_in_unet, **kwargs): + """ + Attention forward function + """ + #print("self.cur_step",self.cur_step) + + if self.cur_step not in self.step_idx: + return super().forward(sim, is_cross, place_in_unet, **kwargs) + + + ### sim for "SparseCausalAttention": (frames, heads=8,res, 2*res) + ### sim for "FullyFrameAttention" : 1, heads, frame*res,frane*res [1, 8, 12288, 12288]) + num_heads = sim.shape[1] + if num_heads == 1: + self.attention_type == "FullyFrameAttention_sliced_attn" + + treg = torch.pow((self.time_steps[self.cur_step]-1)/1000, 5) + + + if not is_cross: + min_value = sim.min(-1)[0].unsqueeze(-1) + max_value = sim.max(-1)[0].unsqueeze(-1) + if self.attention_type == "SparseCausalAttention": + mask = self.sreg_maps[sim.size(2)].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes[sim.size(2)].repeat(1,num_heads,1,1) + elif self.attention_type == "FullyFrameAttention": + mask = self.sreg_maps[sim.size(2)//self.clip_length].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes[sim.size(2)//self.clip_length].repeat(1,num_heads,1,1) + elif self.attention_type == "FullyFrameAttention_sliced_attn": + mask = self.sreg_maps[sim.size(2)//self.clip_length] + size_reg = self.reg_sizes[sim.size(2)//self.clip_length] + + else: + print("unknown attention type") + exit() + # if place_in_unet == "up" and res == 32: + # # h_index 11 w_index =15 + # show_self_attention_comp(sim,video=self.video,h_index=11,w_index=15,res=32,frames=self.clip_length,place_in_unet="up",step=self.cur_step) + #if place_in_unet == "up" and res == 8: + # identify_self_attention_max_min(sim,video=self.video,h_index=3,w_index=4,res=8,frames=self.clip_length,place_in_unet="up",step=self.cur_step) + + sim += (mask>0)*size_reg*self.sreg*treg*(max_value-sim) + sim -= ~(mask>0)*size_reg*self.sreg*treg*(sim-min_value) + + else: + + min_value = sim.min(-1)[0].unsqueeze(-1) + max_value = sim.max(-1)[0].unsqueeze(-1) + mask = self.creg_maps[sim.size(2)].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes_c[sim.size(2)].repeat(1,num_heads,1,1) + + sim += (mask>0)*size_reg*self.creg*treg*(max_value-sim) + sim -= ~(mask>0)*size_reg*self.creg*treg*(sim-min_value) + + self.count +=1 + return sim + + + + +class Attention_Record_Processor(AttentionStore, abc.ABC): + """ record ddim inversion self attention and cross attention """ + + def __init__(self, additional_attention_store: AttentionStore =None,save_self_attention: bool=True,disk_store=False): + super(Attention_Record_Processor, self).__init__( + save_self_attention=save_self_attention, + disk_store=disk_store) + self.additional_attention_store = additional_attention_store + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + + #print("Modulated Ctrl at denoising steps: ", self.step_idx) + + def update_attention_position_dict(self, current_attention_key): + self.attention_position_counter_dict[current_attention_key] +=1 + + + def forward(self, sim, is_cross: bool, place_in_unet: str,**kwargs): + super(Attention_Record_Processor, self).forward(sim, is_cross, place_in_unet,**kwargs) + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.update_attention_position_dict(key) + + return sim + + + def between_steps(self): + + super().between_steps() + self.step_store = self.get_empty_store() + + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + return + + + +class ModulatedAttention_ControlEdit(AttentionStore, abc.ABC): + """Decide self or cross-attention. Call the reweighting cross attention module + + Args: + AttentionStore (_type_): ([1, 4, 8, 64, 64]) + abc (_type_): [8, 8, 1024, 77] + """ + + def __init__(self, end_step=15, total_steps=50, step_idx=None, text_cond=None, sreg_maps=None, creg_maps=None, reg_sizes=None,reg_sizes_c=None, + time_steps=None, + clip_length=None,attention_type=None, + additional_attention_store: AttentionStore =None, + save_self_attention: bool=True, + disk_store=False, + video = None, + ): + """ + Mutual self-attention control for Stable-Diffusion model + Args: + start_step: the step to start mutual self-attention control + start_layer: the layer to start mutual self-attention control + layer_idx: list of the layers to apply mutual self-attention control + step_idx: list the steps to apply mutual self-attention control + total_steps: the total number of steps + model_type: the model type, SD or SDXL + """ + super(ModulatedAttention_ControlEdit, self).__init__( + save_self_attention=save_self_attention, + disk_store=disk_store) + self.total_steps = total_steps + self.step_idx = list(range(0, end_step)) + self.total_infer_steps = 50 + self.text_cond = text_cond + self.sreg_maps = sreg_maps + self.creg_maps = creg_maps + self.reg_sizes = reg_sizes + self.reg_sizes_c = reg_sizes_c + self.clip_length = clip_length + self.attention_type = attention_type + self.sreg = .3 + self.creg = 1. + self.count = 0 + self.reg_part = .3 + self.time_steps = time_steps + self.additional_attention_store = additional_attention_store + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + self.video = video + + def update_attention_position_dict(self, current_attention_key): + self.attention_position_counter_dict[current_attention_key] +=1 + + + def forward(self, sim, is_cross: bool, place_in_unet: str,**kwargs): + super(ModulatedAttention_ControlEdit, self).forward(sim, is_cross, place_in_unet,**kwargs) + + # print("self.cur_step",self.cur_step) + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.update_attention_position_dict(key) + + if self.cur_step not in self.step_idx: + return sim + + + num_heads = sim.shape[1] + if num_heads == 1: + self.attention_type == "FullyFrameAttention_sliced_attn" + + treg = torch.pow((self.time_steps[self.cur_step]-1)/1000, 5) + + + if not is_cross: + ## Modulate self-attention + min_value = sim.min(-1)[0].unsqueeze(-1) + max_value = sim.max(-1)[0].unsqueeze(-1) + + if self.attention_type == "SparseCausalAttention": + mask = self.sreg_maps[sim.size(2)].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes[sim.size(2)].repeat(1,num_heads,1,1) + elif self.attention_type == "FullyFrameAttention": + mask = self.sreg_maps[sim.size(2)//self.clip_length].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes[sim.size(2)//self.clip_length].repeat(1,num_heads,1,1) + elif self.attention_type == "FullyFrameAttention_sliced_attn": + mask = self.sreg_maps[sim.size(2)//self.clip_length] + size_reg = self.reg_sizes[sim.size(2)//self.clip_length] + + else: + print("unknown attention type") + exit() + + sim += (mask>0)*size_reg*self.sreg*treg*(max_value-sim) + sim -= ~(mask>0)*size_reg*self.sreg*treg*(sim-min_value) + + else: + #Modulate cross-attention + + min_value = sim.min(-1)[0].unsqueeze(-1) + max_value = sim.max(-1)[0].unsqueeze(-1) + mask = self.creg_maps[sim.size(2)].repeat(1,num_heads,1,1) + size_reg = self.reg_sizes_c[sim.size(2)].repeat(1,num_heads,1,1) + sim += (mask>0)*size_reg*self.creg*treg*(max_value-sim) + sim -= ~(mask>0)*size_reg*self.creg*treg*(sim-min_value) + self.count +=1 + return sim + + + def between_steps(self): + + super().between_steps() + self.step_store = self.get_empty_store() + + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + return diff --git a/video_diffusion/prompt_attention/free_lunch_utils.py b/video_diffusion/prompt_attention/free_lunch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98c9e95e0026ff23bb34160dcea2603850aea2ce --- /dev/null +++ b/video_diffusion/prompt_attention/free_lunch_utils.py @@ -0,0 +1,384 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.fft as fft +from diffusers.utils import is_torch_version +from diffusers.models.unet_2d_condition import logger as logger2d +from diffusers.models.unet_3d_condition import logger as logger3d +from einops import rearrange + +def isinstance_str(x: object, cls_name: str): + """ + Checks whether x has any class *named* cls_name in its ancestry. + Doesn't require access to the class's implementation. + + Useful for patching! + """ + + for _cls in x.__class__.__mro__: + if _cls.__name__ == cls_name: + return True + + return False + + +def Fourier_filter(x_in, threshold, scale): + """ + Updated Fourier filter based on: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + + # FFT + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(dtype=x_in.dtype) + + +def register_upblock2d(model): + """ + Register UpBlock2D for UNet2DCondition. + """ + + def up_forward(self): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None + ): + logger2d.debug(f"in upblock2d, hidden states shape: {hidden_states.shape}") + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlockPseudo3D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register UpBlock2D with FreeU for UNet2DCondition. + """ + + def up_forward(self): + def forward( + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None + ): + logger2d.debug(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + # for video latents + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + res_hidden_states = rearrange(res_hidden_states, "b c f h w -> (b f) c h w") + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_mean = hidden_states.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1) + #hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_mean = hidden_states.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1) + #hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + res_hidden_states = rearrange(res_hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlockPseudo3D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_crossattn_upblock2d(model): + """ + Register CrossAttn UpBlock2D for UNet2DCondition. + """ + + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + logger2d.debug(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + + #lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlockPseudo3D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + """ + Register CrossAttn UpBlock2D with FreeU for UNet2DCondition. + """ + + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ): + logger2d.debug(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + + #lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + # for video latents + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + res_hidden_states = rearrange(res_hidden_states, "b c f h w -> (b f) c h w") + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_mean = hidden_states.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1) + #hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_mean = hidden_states.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1) + #hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + res_hidden_states = rearrange(res_hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + **kwargs + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlockPseudo3D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + +def apply_freeu(pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0): + register_free_upblock2d(pipe, b1, b2, s1, s2) + register_free_crossattn_upblock2d(pipe, b1, b2, s1, s2) \ No newline at end of file diff --git a/video_diffusion/prompt_attention/ptp_utils.py b/video_diffusion/prompt_attention/ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..506fafe02e8d45a5878115ac6c2b8fb0e20c9b2c --- /dev/null +++ b/video_diffusion/prompt_attention/ptp_utils.py @@ -0,0 +1,200 @@ +''' +utils code for image visualization +''' + + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image +import cv2 +from typing import Optional, Union, Tuple, List, Callable, Dict + +import datetime + + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + if save_path is not None: + pil_img = Image.fromarray(image_) + # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + pil_img.save(f'{save_path}') + #pil_img.save(f'{save_path}/{now}.png') + # display(pil_img) + + + +def register_attention_control_p2p_deprecated(model, controller): + "Original code from prompt to prompt" + def ca_forward(self, place_in_unet): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + # def forward(x, encoder_hidden_states=None, attention_mask=None): + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = self.to_q(hidden_states) + query = self.head_to_batch_dim(query) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + key = self.head_to_batch_dim(key) + value = self.head_to_batch_dim(value) + + attention_probs = self.get_attention_scores(query, key, attention_mask) # [16, 4096, 4096] + attention_probs = controller(attention_probs, is_cross, place_in_unet) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + return forward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + + controller.num_att_layers = cross_att_count + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, + word_inds: Optional[torch.Tensor]=None): + # Edit the alpha map during attention map editing + if type(bounds) is float: + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[: start, prompt_ind, word_inds] = 0 + alpha[start: end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + +import omegaconf +def get_time_words_attention_alpha(prompts, num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, max_num_words=77): + # Not understand + if (type(cross_replace_steps) is not dict) and \ + (type(cross_replace_steps) is not omegaconf.dictconfig.DictConfig): + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0., 1.) + alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], + i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words diff --git a/video_diffusion/prompt_attention/sd_study_utils.py b/video_diffusion/prompt_attention/sd_study_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a454a6273ceb5fab39b5d020e812b44dc69cce1 --- /dev/null +++ b/video_diffusion/prompt_attention/sd_study_utils.py @@ -0,0 +1,200 @@ +from typing import Optional, Union, Tuple, List, Callable, Dict +from tqdm.notebook import tqdm +import torch + +import math +from typing import List, Tuple, Union +from PIL import Image +import cv2 +import numpy as np +import os +import re +import torch +from IPython.display import display +from sklearn.cluster import KMeans +import matplotlib.pyplot as plt +from .ptp_utils import * + +import torchvision.transforms as transforms +from sklearn.decomposition import PCA +import pickle as pkl +import torch.nn.functional as F +import argparse +from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score, v_measure_score + +transform_train = transforms.Compose([ + transforms.ToPILImage(), +]) +pca = PCA(n_components=3) + +def save_mask(mask, output_name): + mask_image = transform_train(mask.float()) + mask_image.save(output_name) + + +def show_image(image): + image = 255 * image / image.max() + image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.numpy().astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + return image + +def cluster2noun_mod(clusters, background_segment_threshold, num_segments, nouns, cross_attention): + REPEAT=clusters.shape[0]/cross_attention.shape[0] + + result = {} + result_mask={} + nouns_indices = [index for (index, word) in nouns] + nouns_maps = cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]] + nouns_maps = cross_attention.unsqueeze(-1).cpu().numpy() + normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) + for i in range(nouns_maps.shape[-1]): + curr_noun_map = nouns_maps[:, :, i].repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) + normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max() + for c in range(num_segments): + cluster_mask = np.zeros_like(clusters) + cluster_mask[clusters == c] = 1 + score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))] + scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps] + result[c] = nouns[np.argmax(np.array(scores))] if max(scores) > background_segment_threshold else "BG" + result_mask[c]=cluster_mask + return result, result_mask + +def cluster2noun_(clusters, background_segment_threshold, num_segments, nouns, cross_attention): + REPEAT=clusters.shape[0]/cross_attention.shape[0] + + result = {} + result_mask={} + nouns_indices = [index for (index, word) in nouns] + nouns_maps = cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]] + normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) + for i in range(nouns_maps.shape[-1]): + curr_noun_map = nouns_maps[:, :, i].repeat(REPEAT, axis=0).repeat(REPEAT, axis=1) + normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max() + for c in range(num_segments): + cluster_mask = np.zeros_like(clusters) + cluster_mask[clusters == c] = 1 + score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))] + scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps] + result[c] = nouns[np.argmax(np.array(scores))] if max(scores) > background_segment_threshold else "BG" + result_mask[c]=cluster_mask + return result, result_mask + + +def aggregate_attention( attention_maps, + res: int, from_where: List[str], + is_cross: bool, select: int, prompts,): + out = [] + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out.cpu(), attention_maps + +def cluster(self_attention, num_segments,): + np.random.seed(1) + frames = self_attention.shape[0] + video_clusters = [] + for i in range(frames): + per_frame_attention = self_attention[i] + print('per_frame_attention',per_frame_attention.shape) + resolution, feat_dim = per_frame_attention.shape[0], per_frame_attention.shape[-1] + attn = per_frame_attention.cpu().numpy().reshape(resolution ** 2, feat_dim) + kmeans = KMeans(n_clusters=num_segments, n_init=10).fit(attn) + clusters = kmeans.labels_ + clusters = clusters.reshape(resolution, resolution) + video_clusters.append(clusters) + return video_clusters + +def run_clusters(avg_dict, resolution, dict_key, save_path, special_name, num_segments): + video_clusters = cluster(avg_dict[dict_key][resolution], num_segments,) + npy_name=f'cluster_{dict_key}_{resolution}_{special_name}.npy' + np.save(os.path.join(save_path, npy_name), video_clusters) + for i in range(len(video_clusters)): + clusters = video_clusters[i] + output_name=f'cluster_{dict_key}_{resolution}_{i}.png' + plt.imshow(clusters) + plt.axis('off') + plt.savefig(os.path.join(save_path, output_name), bbox_inches='tight', pad_inches=0) + +def read_pkl(path,): + with open(path,'rb') as f: + dict_ = pkl.load(f) + return dict_ + + +def draw_pca(avg_dict, resolution, dict_key, save_path, special_name): + + RESOLUTION=resolution + + if avg_dict[dict_key][RESOLUTION].__len__() == 0: + return + before_pca = avg_dict[dict_key][RESOLUTION] + + frames = before_pca.shape[0] + for i in range(frames): + frame = before_pca[i] + print('frame',frame.dtype) + if isinstance(frame, torch.Tensor): + frame = frame.reshape(RESOLUTION * RESOLUTION, -1).cpu().numpy() + else: + frame = frame.reshape(RESOLUTION * RESOLUTION, -1) + pca.fit(frame) + after_pca = pca.transform(frame) + + after_pca = after_pca.reshape(RESOLUTION,RESOLUTION,-1) + pca_img_min = after_pca.min(axis=(0, 1)) + pca_img_max = after_pca.max(axis=(0, 1)) + pca_img = (after_pca - pca_img_min) / (pca_img_max - pca_img_min) + + output_name=f'pca_{dict_key}_{resolution}_{i}.png' + pca_img = Image.fromarray((pca_img * 255).astype(np.uint8)) + pca_img=pca_img.resize((512,512)) + pca_img.save(os.path.join(save_path, output_name)) + +def image_normalize(numpy_array, save_path,output_name): + numpy_array=numpy_array.cpu().numpy() + img_min = numpy_array.min() + img_max = numpy_array.max() + normalize_array = (numpy_array - img_min) / (img_max - img_min) + + plt.imshow(normalize_array) + plt.axis('off') + plt.savefig(os.path.join(save_path, output_name), bbox_inches='tight', pad_inches=0) + + +def cross_cosine_with_name(resolution, inv_avg_dict, denoise_avg_dict, indice, save_path, save_crossattn=False, noun_name = ''): + inv_cross_attn = inv_avg_dict['attn'][resolution][:,:,indice] + denoise_cross_attn = denoise_avg_dict['attn'][resolution][:,:,indice] + if save_crossattn: + image_normalize(inv_cross_attn, save_path, f'crossattn_{resolution}_inv_{noun_name}.png') + image_normalize(denoise_cross_attn, save_path, f'crossattn_{resolution}_denoise_{noun_name}.png') + + return F.cosine_similarity(inv_cross_attn.reshape(1,-1), denoise_cross_attn.reshape(1,-1)) + +def cross_cosine(resolution, inv_avg_dict, denoise_avg_dict, indice, save_path, save_crossattn=False,): + inv_cross_attn = inv_avg_dict['attn'][resolution][:,:,indice] + denoise_cross_attn = denoise_avg_dict['attn'][resolution][:,:,indice] + if save_crossattn: + image_normalize(inv_cross_attn, save_path, f'crossattn_{resolution}_inv.png') + image_normalize(denoise_cross_attn, save_path, f'crossattn_{resolution}_denoise.png') + + return F.cosine_similarity(inv_cross_attn.reshape(1,-1), denoise_cross_attn.reshape(1,-1)) + +def save_crossattn(input_path, caption, inv_cross_avg_dict, denoise_cross_avg_dict, results_folder, RES=16): + org_image = Image.open(input_path).convert("RGB") + prompts=["<|startoftext|>",] + caption.split(' ') + ["<|endoftext|>",] + inv_crossattn = inv_cross_avg_dict['attn'][RES] + denoise_crossattn = denoise_cross_avg_dict['attn'][RES] + + attn_img1, mask_img1, _ = show_cross_attention_plus_orig_img(prompts, inv_crossattn, orig_image=org_image) + attn_img2, mask_img2, _ = show_cross_attention_plus_orig_img(prompts, denoise_crossattn, orig_image=org_image) + + attn_img1.save(os.path.join(results_folder,'crossattn_inv.png')) + attn_img2.save(os.path.join(results_folder,'crossattn_denoise.png')) + mask_img1.save(os.path.join(results_folder,'crossattn_inv_mask.png')) + mask_img2.save(os.path.join(results_folder,'crossattn_denoise_mask.png')) \ No newline at end of file diff --git a/video_diffusion/prompt_attention/visualization.py b/video_diffusion/prompt_attention/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..388683e01fde4779598f8468b0a3fe1f997a7d3d --- /dev/null +++ b/video_diffusion/prompt_attention/visualization.py @@ -0,0 +1,225 @@ +from typing import List +import os +import datetime +import numpy as np +from PIL import Image + +import torch + +import video_diffusion.prompt_attention.ptp_utils as ptp_utils +from video_diffusion.common.image_util import save_gif_mp4_folder_type +from video_diffusion.prompt_attention.attention_store import AttentionStore +import cv2 +from IPython.display import display +from typing import List, Tuple, Union + +def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): + out = [] + attention_maps = attention_store.get_average_attention() + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + #print('item',item.shape) + if item.dim() == 3: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + elif item.dim() == 4: + t, h, res_sq, token = item.shape + if item.shape[2] == num_pixels: + cross_maps = item.reshape(len(prompts), t, -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + + out = torch.cat(out, dim=-4) + out = out.sum(-4) / out.shape[-4] + return out.cpu() + + +def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, + res: int, from_where: List[str], select: int = 0, save_path = None): + """ + attention_store (AttentionStore): + ["down", "mid", "up"] X ["self", "cross"] + 4, 1, 6 + head*res*text_token_len = 8*res*77 + res=1024 -> 64 -> 1024 + res (int): res + from_where (List[str]): "up", "down' + """ + if isinstance(prompts, str): + prompts = [prompts,] + tokens = tokenizer.encode(prompts[select]) + decoder = tokenizer.decode + + attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select) + os.makedirs('trash', exist_ok=True) + attention_list = [] + if attention_maps.dim()==3: attention_maps=attention_maps[None, ...] + for j in range(attention_maps.shape[0]): + images = [] + for i in range(len(tokens)): + image = attention_maps[j, :, :, i] + image = 255 * image / image.max() + image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.numpy().astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) + images.append(image) + ptp_utils.view_images(np.stack(images, axis=0), save_path=save_path) + atten_j = np.concatenate(images, axis=1) + attention_list.append(atten_j) + if save_path is not None: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + video_save_path = f'{save_path}/{now}.gif' + save_gif_mp4_folder_type(attention_list, video_save_path) + return attention_list + +def tensor_to_pil(image_tensor): + # 首先确保tensor在CPU上 + image_tensor = image_tensor.cpu() + # 将C,H,W转换为H,W,C + image_tensor = image_tensor.permute(1, 2, 0) + # 正规化到[0,1] + image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min()) + # 转换为255范围的uint8 + image_array = np.uint8(255 * image_tensor) + # 创建PIL图像 + image_pil = Image.fromarray(image_array) + return image_pil + +def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16): + # create heatmap from mask on image + def show_cam_on_image(img, mask): + heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) + heatmap = np.float32(heatmap) / 255 + cam = heatmap + np.float32(img) + cam = cam / np.max(cam) + return cam + image = tensor_to_pil(image) + image = image.resize((relevnace_res ** 2, relevnace_res ** 2)) + image = np.array(image) + + image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1]) + image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu + image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear') + image_relevance = image_relevance.cpu() # send it back to cpu + image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) + image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2) + image = (image - image.min()) / (image.max() - image.min()+1e-8) + vis = show_cam_on_image(image, image_relevance) + vis = np.uint8(255 * vis) + vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) + return vis + + + +def show_cross_attention_plus_org_img(tokenizer, prompts,org_images, attention_store: AttentionStore, + res: int, from_where: List[str], select: int = 0, save_path = None, attention_maps=None): + """ + attention_store (AttentionStore): + ["down", "mid", "up"] X ["self", "cross"] + 4, 1, 6 + head*res*text_token_len = 8*res*77 + res=1024 -> 64 -> 1024 + res (int): res + from_where (List[str]): "up", "down' + image: f c h w + """ + + + if isinstance(prompts, str): + prompts = [prompts,] + tokens = tokenizer.encode(prompts[select]) + decoder = tokenizer.decode + if attention_maps is None: + print('res',res) + attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select) + else: + attention_maps = attention_maps + os.makedirs('trash', exist_ok=True) + attention_list = [] + if attention_maps.dim()==3: attention_maps=attention_maps[None, ...] + for j in range(attention_maps.shape[0]): + images = [] + for i in range(len(tokens)): + image = attention_maps[j, :, :, i] + orig_image = org_images[j] + image = show_image_relevance(image, orig_image) + + # image = 255 * image / image.max() + # image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) + images.append(image) + frame_save_path = os.path.join(save_path,f'frame_{j}_cross_attn.jpg') + ptp_utils.view_images(np.stack(images, axis=0), save_path=frame_save_path) + atten_j = np.concatenate(images, axis=1) + attention_list.append(atten_j) + if save_path is not None: + # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + video_save_path = os.path.join(save_path,'cross_attn.gif') + save_gif_mp4_folder_type(attention_list, video_save_path, save_gif=False) + return attention_list + + +def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], + max_com=10, select: int = 0): + attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) + u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) + images = [] + for i in range(max_com): + image = vh[i].reshape(res, res) + image = image - image.min() + image = 255 * image / image.max() + image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) + image = Image.fromarray(image).resize((256, 256)) + image = np.array(image) + images.append(image) + ptp_utils.view_images(np.concatenate(images, axis=1)) + + +def view_images(images: Union[np.ndarray, List], + num_rows: int = 1, + offset_ratio: float = 0.02, + display_image: bool = True) -> Image.Image: + """ Displays a list of images in a grid. """ + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + if display_image: + display(pil_img) + return pil_img + + + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + img[:h] = image + textsize = cv2.getTextSize(text, font, fontScale=1, thickness=2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) + return img