Spaces:
Runtime error
Runtime error
add utils
Browse files- utils/__init__.py +0 -0
- utils/booru_tagger.py +116 -0
- utils/constants.py +82 -0
- utils/cupy_utils.py +122 -0
- utils/effects.py +182 -0
- utils/env_utils.py +65 -0
- utils/helper_math.h +1449 -0
- utils/io_utils.py +473 -0
- utils/logger.py +20 -0
- utils/mmdet_custom_hooks.py +223 -0
utils/__init__.py
ADDED
|
File without changes
|
utils/booru_tagger.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from onnxruntime import InferenceSession
|
| 6 |
+
from typing import Tuple, List, Dict
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
def make_square(img, target_size):
|
| 16 |
+
old_size = img.shape[:2]
|
| 17 |
+
desired_size = max(old_size)
|
| 18 |
+
desired_size = max(desired_size, target_size)
|
| 19 |
+
|
| 20 |
+
delta_w = desired_size - old_size[1]
|
| 21 |
+
delta_h = desired_size - old_size[0]
|
| 22 |
+
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
|
| 23 |
+
left, right = delta_w // 2, delta_w - (delta_w // 2)
|
| 24 |
+
|
| 25 |
+
color = [255, 255, 255]
|
| 26 |
+
new_im = cv2.copyMakeBorder(
|
| 27 |
+
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
| 28 |
+
)
|
| 29 |
+
return new_im
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def smart_resize(img, size):
|
| 33 |
+
# Assumes the image has already gone through make_square
|
| 34 |
+
if img.shape[0] > size:
|
| 35 |
+
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
| 36 |
+
elif img.shape[0] < size:
|
| 37 |
+
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
|
| 38 |
+
return img
|
| 39 |
+
|
| 40 |
+
class Tagger :
|
| 41 |
+
def __init__(self, filename) -> None:
|
| 42 |
+
self.model = InferenceSession(filename, providers=['CUDAExecutionProvider'])
|
| 43 |
+
[root, _] = os.path.split(filename)
|
| 44 |
+
self.tags = pd.read_csv(os.path.join(root, 'selected_tags.csv') if root else 'selected_tags.csv')
|
| 45 |
+
|
| 46 |
+
_, self.height, _, _ = self.model.get_inputs()[0].shape
|
| 47 |
+
|
| 48 |
+
characters = self.tags.loc[self.tags['category'] == 4]
|
| 49 |
+
self.characters = set(characters['name'].values.tolist())
|
| 50 |
+
|
| 51 |
+
def label(self, image: Image) -> Dict[str, float] :
|
| 52 |
+
# alpha to white
|
| 53 |
+
image = image.convert('RGBA')
|
| 54 |
+
new_image = Image.new('RGBA', image.size, 'WHITE')
|
| 55 |
+
new_image.paste(image, mask=image)
|
| 56 |
+
image = new_image.convert('RGB')
|
| 57 |
+
image = np.asarray(image)
|
| 58 |
+
|
| 59 |
+
# PIL RGB to OpenCV BGR
|
| 60 |
+
image = image[:, :, ::-1]
|
| 61 |
+
|
| 62 |
+
image = make_square(image, self.height)
|
| 63 |
+
image = smart_resize(image, self.height)
|
| 64 |
+
image = image.astype(np.float32)
|
| 65 |
+
image = np.expand_dims(image, 0)
|
| 66 |
+
|
| 67 |
+
# evaluate model
|
| 68 |
+
input_name = self.model.get_inputs()[0].name
|
| 69 |
+
label_name = self.model.get_outputs()[0].name
|
| 70 |
+
confidents = self.model.run([label_name], {input_name: image})[0]
|
| 71 |
+
|
| 72 |
+
tags = self.tags[:][['name']]
|
| 73 |
+
tags['confidents'] = confidents[0]
|
| 74 |
+
|
| 75 |
+
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
| 76 |
+
ratings = dict(tags[:4].values)
|
| 77 |
+
|
| 78 |
+
# rest are regular tags
|
| 79 |
+
tags = dict(tags[4:].values)
|
| 80 |
+
|
| 81 |
+
tags = {t: v for t, v in tags.items() if v > 0.5}
|
| 82 |
+
return tags
|
| 83 |
+
|
| 84 |
+
def label_cv2_bgr(self, image: np.ndarray) -> Dict[str, float] :
|
| 85 |
+
# image in BGR u8
|
| 86 |
+
image = make_square(image, self.height)
|
| 87 |
+
image = smart_resize(image, self.height)
|
| 88 |
+
image = image.astype(np.float32)
|
| 89 |
+
image = np.expand_dims(image, 0)
|
| 90 |
+
|
| 91 |
+
# evaluate model
|
| 92 |
+
input_name = self.model.get_inputs()[0].name
|
| 93 |
+
label_name = self.model.get_outputs()[0].name
|
| 94 |
+
confidents = self.model.run([label_name], {input_name: image})[0]
|
| 95 |
+
|
| 96 |
+
tags = self.tags[:][['name']]
|
| 97 |
+
cats = self.tags[:][['category']]
|
| 98 |
+
tags['confidents'] = confidents[0]
|
| 99 |
+
|
| 100 |
+
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
| 101 |
+
ratings = dict(tags[:4].values)
|
| 102 |
+
|
| 103 |
+
# rest are regular tags
|
| 104 |
+
tags = dict(tags[4:].values)
|
| 105 |
+
|
| 106 |
+
tags = [t for t, v in tags.items() if v > 0.5]
|
| 107 |
+
character_str = []
|
| 108 |
+
for t in tags:
|
| 109 |
+
if t in self.characters:
|
| 110 |
+
character_str.append(t)
|
| 111 |
+
return tags, character_str
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
|
| 115 |
+
modelp = r'models/wd-v1-4-swinv2-tagger-v2/model.onnx'
|
| 116 |
+
tagger = Tagger(modelp)
|
utils/constants.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
CATEGORIES = [
|
| 4 |
+
{"id": 0, "name": "object", "isthing": 1}
|
| 5 |
+
]
|
| 6 |
+
|
| 7 |
+
IMAGE_ID_ZFILL = 12
|
| 8 |
+
|
| 9 |
+
COLOR_PALETTE = [
|
| 10 |
+
(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
|
| 11 |
+
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
|
| 12 |
+
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
|
| 13 |
+
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
|
| 14 |
+
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
|
| 15 |
+
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
|
| 16 |
+
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
|
| 17 |
+
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
|
| 18 |
+
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
|
| 19 |
+
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
|
| 20 |
+
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
|
| 21 |
+
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
|
| 22 |
+
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
|
| 23 |
+
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
|
| 24 |
+
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
|
| 25 |
+
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
|
| 26 |
+
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
|
| 27 |
+
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
|
| 28 |
+
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
|
| 29 |
+
(246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
|
| 30 |
+
(150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
|
| 31 |
+
(92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
|
| 32 |
+
(124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
|
| 33 |
+
(193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
|
| 34 |
+
(230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
|
| 35 |
+
(254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
|
| 36 |
+
(104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
|
| 37 |
+
(135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
|
| 38 |
+
(183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
|
| 39 |
+
(146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
|
| 40 |
+
(96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
|
| 41 |
+
(208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
|
| 42 |
+
(0, 114, 143), (102, 102, 156), (250, 141, 255)
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
class Colors:
|
| 46 |
+
# Ultralytics color palette https://ultralytics.com/
|
| 47 |
+
def __init__(self):
|
| 48 |
+
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
| 49 |
+
hexs = ('FF1010', '10FF10', 'FFF010', '100FFF', '0018EC', 'FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
| 50 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
| 51 |
+
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
| 52 |
+
self.n = len(self.palette)
|
| 53 |
+
|
| 54 |
+
def __call__(self, i, bgr=True):
|
| 55 |
+
c = self.palette[int(i) % self.n]
|
| 56 |
+
return (c[2], c[1], c[0]) if bgr else c
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def hex2rgb(h): # rgb order (PIL)
|
| 60 |
+
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
| 61 |
+
|
| 62 |
+
colors = Colors()
|
| 63 |
+
def get_color(idx):
|
| 64 |
+
if idx == -1:
|
| 65 |
+
return 255
|
| 66 |
+
else:
|
| 67 |
+
return colors(idx)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
MULTIPLE_TAGS = {'2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls',
|
| 71 |
+
'2boys', '3boys', '4boys', '5boys', '6+boys', 'multiple_boys',
|
| 72 |
+
'2others', '3others', '4others', '5others', '6+others', 'multiple_others'}
|
| 73 |
+
|
| 74 |
+
if hasattr(torch, 'cuda'):
|
| 75 |
+
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 76 |
+
else:
|
| 77 |
+
DEFAULT_DEVICE = 'cpu'
|
| 78 |
+
|
| 79 |
+
DEFAULT_DETECTOR_CKPT = 'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
|
| 80 |
+
DEFAULT_DEPTHREFINE_CKPT = 'models/AnimeInstanceSegmentation/kenburns_depth_refinenet.ckpt'
|
| 81 |
+
DEFAULT_INPAINTNET_CKPT = 'models/AnimeInstanceSegmentation/kenburns_inpaintnet.ckpt'
|
| 82 |
+
DEPTH_ZOE_CKPT = 'models/AnimeInstanceSegmentation/ZoeD_M12_N.pt'
|
utils/cupy_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import os
|
| 3 |
+
import cupy
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
@cupy.memoize(for_each_device=True)
|
| 8 |
+
def launch_kernel(strFunction, strKernel):
|
| 9 |
+
if 'CUDA_HOME' not in os.environ:
|
| 10 |
+
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
| 11 |
+
# end
|
| 12 |
+
# , options=tuple([ '-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include' ])
|
| 13 |
+
return cupy.RawKernel(strKernel, strFunction)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def preprocess_kernel(strKernel, objVariables):
|
| 17 |
+
path_to_math_helper = osp.join(osp.dirname(osp.abspath(__file__)), 'helper_math.h')
|
| 18 |
+
strKernel = '''
|
| 19 |
+
#include <{{HELPER_PATH}}>
|
| 20 |
+
|
| 21 |
+
__device__ __forceinline__ float atomicMin(const float* buffer, float dblValue) {
|
| 22 |
+
int intValue = __float_as_int(*buffer);
|
| 23 |
+
|
| 24 |
+
while (__int_as_float(intValue) > dblValue) {
|
| 25 |
+
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
return __int_as_float(intValue);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
__device__ __forceinline__ float atomicMax(const float* buffer, float dblValue) {
|
| 33 |
+
int intValue = __float_as_int(*buffer);
|
| 34 |
+
|
| 35 |
+
while (__int_as_float(intValue) < dblValue) {
|
| 36 |
+
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
return __int_as_float(intValue);
|
| 40 |
+
}
|
| 41 |
+
'''.replace('{{HELPER_PATH}}', path_to_math_helper) + strKernel
|
| 42 |
+
# end
|
| 43 |
+
|
| 44 |
+
for strVariable in objVariables:
|
| 45 |
+
objValue = objVariables[strVariable]
|
| 46 |
+
|
| 47 |
+
if type(objValue) == int:
|
| 48 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
| 49 |
+
|
| 50 |
+
elif type(objValue) == float:
|
| 51 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
| 52 |
+
|
| 53 |
+
elif type(objValue) == str:
|
| 54 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
|
| 55 |
+
|
| 56 |
+
# end
|
| 57 |
+
# end
|
| 58 |
+
|
| 59 |
+
while True:
|
| 60 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
| 61 |
+
|
| 62 |
+
if objMatch is None:
|
| 63 |
+
break
|
| 64 |
+
# end
|
| 65 |
+
|
| 66 |
+
intArg = int(objMatch.group(2))
|
| 67 |
+
|
| 68 |
+
strTensor = objMatch.group(4)
|
| 69 |
+
intSizes = objVariables[strTensor].size()
|
| 70 |
+
|
| 71 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
|
| 72 |
+
# end
|
| 73 |
+
|
| 74 |
+
while True:
|
| 75 |
+
objMatch = re.search('(STRIDE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
| 76 |
+
|
| 77 |
+
if objMatch is None:
|
| 78 |
+
break
|
| 79 |
+
# end
|
| 80 |
+
|
| 81 |
+
intArg = int(objMatch.group(2))
|
| 82 |
+
|
| 83 |
+
strTensor = objMatch.group(4)
|
| 84 |
+
intStrides = objVariables[strTensor].stride()
|
| 85 |
+
|
| 86 |
+
strKernel = strKernel.replace(objMatch.group(), str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()))
|
| 87 |
+
# end
|
| 88 |
+
|
| 89 |
+
while True:
|
| 90 |
+
objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
|
| 91 |
+
|
| 92 |
+
if objMatch is None:
|
| 93 |
+
break
|
| 94 |
+
# end
|
| 95 |
+
|
| 96 |
+
intArgs = int(objMatch.group(2))
|
| 97 |
+
strArgs = objMatch.group(4).split(',')
|
| 98 |
+
|
| 99 |
+
strTensor = strArgs[0]
|
| 100 |
+
intStrides = objVariables[strTensor].stride()
|
| 101 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
|
| 102 |
+
|
| 103 |
+
strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
|
| 104 |
+
# end
|
| 105 |
+
|
| 106 |
+
while True:
|
| 107 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
| 108 |
+
|
| 109 |
+
if objMatch is None:
|
| 110 |
+
break
|
| 111 |
+
# end
|
| 112 |
+
|
| 113 |
+
intArgs = int(objMatch.group(2))
|
| 114 |
+
strArgs = objMatch.group(4).split(',')
|
| 115 |
+
|
| 116 |
+
strTensor = strArgs[0]
|
| 117 |
+
intStrides = objVariables[strTensor].stride()
|
| 118 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
|
| 119 |
+
|
| 120 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
| 121 |
+
# end
|
| 122 |
+
return strKernel
|
utils/effects.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numba import jit, njit
|
| 2 |
+
import numpy as np
|
| 3 |
+
import time
|
| 4 |
+
import cv2
|
| 5 |
+
import math
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import torch
|
| 9 |
+
from .cupy_utils import launch_kernel, preprocess_kernel
|
| 10 |
+
import cupy
|
| 11 |
+
|
| 12 |
+
def bokeh_filter_cupy(img, depth, dx, dy, im_h, im_w, num_samples=32):
|
| 13 |
+
blurred = img.clone()
|
| 14 |
+
n = im_h * im_w
|
| 15 |
+
|
| 16 |
+
str_kernel = '''
|
| 17 |
+
extern "C" __global__ void kernel_bokeh(
|
| 18 |
+
const int n,
|
| 19 |
+
const int h,
|
| 20 |
+
const int w,
|
| 21 |
+
const int nsamples,
|
| 22 |
+
const float dx,
|
| 23 |
+
const float dy,
|
| 24 |
+
const float* img,
|
| 25 |
+
const float* depth,
|
| 26 |
+
float* blurred
|
| 27 |
+
) {
|
| 28 |
+
|
| 29 |
+
const int im_size = min(h, w);
|
| 30 |
+
const int sample_offset = nsamples / 2;
|
| 31 |
+
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n * 3; intIndex += blockDim.x * gridDim.x) {
|
| 32 |
+
|
| 33 |
+
const int intSample = intIndex / 3;
|
| 34 |
+
|
| 35 |
+
const int c = intIndex % 3;
|
| 36 |
+
const int y = ( intSample / w) % h;
|
| 37 |
+
const int x = intSample % w;
|
| 38 |
+
|
| 39 |
+
const int flatten_xy = y * w + x;
|
| 40 |
+
const int fid = flatten_xy * 3 + c;
|
| 41 |
+
const float d = depth[flatten_xy];
|
| 42 |
+
|
| 43 |
+
const float _dx = dx * d;
|
| 44 |
+
const float _dy = dy * d;
|
| 45 |
+
float weight = 0;
|
| 46 |
+
float color = 0;
|
| 47 |
+
for (int s = 0; s < nsamples; s += 1) {
|
| 48 |
+
|
| 49 |
+
const int sp = (s - sample_offset) * im_size;
|
| 50 |
+
const int x_ = x + int(round(_dx * sp));
|
| 51 |
+
const int y_ = y + int(round(_dy * sp));
|
| 52 |
+
|
| 53 |
+
if ((x_ >= w) | (y_ >= h) | (x_ < 0) | (y_ < 0))
|
| 54 |
+
continue;
|
| 55 |
+
|
| 56 |
+
const int flatten_xy_ = y_ * w + x_;
|
| 57 |
+
const float w_ = depth[flatten_xy_];
|
| 58 |
+
weight += w_;
|
| 59 |
+
const int fid_ = flatten_xy_ * 3 + c;
|
| 60 |
+
color += img[fid_] * w_;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
if (weight != 0) {
|
| 64 |
+
color /= weight;
|
| 65 |
+
}
|
| 66 |
+
else {
|
| 67 |
+
color = img[fid];
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
blurred[fid] = color;
|
| 71 |
+
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
}
|
| 75 |
+
'''
|
| 76 |
+
launch_kernel('kernel_bokeh', str_kernel)(
|
| 77 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
| 78 |
+
block=tuple([ 512, 1, 1 ]),
|
| 79 |
+
args=[ cupy.int32(n), cupy.int32(im_h), cupy.int32(im_w), \
|
| 80 |
+
cupy.int32(num_samples), cupy.float32(dx), cupy.float32(dy),
|
| 81 |
+
img.data_ptr(), depth.data_ptr(), blurred.data_ptr() ]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
return blurred
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def np2flatten_tensor(arr: np.ndarray, to_cuda: bool = True) -> torch.Tensor:
|
| 88 |
+
c = 1
|
| 89 |
+
if len(arr.shape) == 3:
|
| 90 |
+
c = arr.shape[2]
|
| 91 |
+
else:
|
| 92 |
+
arr = arr[..., None]
|
| 93 |
+
arr = arr.transpose((2, 0, 1))[None, ...]
|
| 94 |
+
t = torch.from_numpy(arr).view(1, c, -1)
|
| 95 |
+
|
| 96 |
+
if to_cuda:
|
| 97 |
+
t = t.cuda()
|
| 98 |
+
return t
|
| 99 |
+
|
| 100 |
+
def ftensor2img(t: torch.Tensor, im_h, im_w):
|
| 101 |
+
t = t.detach().cpu().numpy().squeeze()
|
| 102 |
+
c = t.shape[0]
|
| 103 |
+
t = t.transpose((1, 0)).reshape((im_h, im_w, c))
|
| 104 |
+
return t
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@njit
|
| 108 |
+
def bokeh_filter(img, depth, dx, dy, num_samples=32):
|
| 109 |
+
|
| 110 |
+
sample_offset = num_samples // 2
|
| 111 |
+
# _scale = 0.0005
|
| 112 |
+
# depth = depth * _scale
|
| 113 |
+
|
| 114 |
+
im_h, im_w = img.shape[0], img.shape[1]
|
| 115 |
+
im_size = min(im_h, im_w)
|
| 116 |
+
blured = np.zeros_like(img)
|
| 117 |
+
for x in range(im_w):
|
| 118 |
+
for y in range(im_h):
|
| 119 |
+
d = depth[y, x]
|
| 120 |
+
_color = np.array([0, 0, 0], dtype=np.float32)
|
| 121 |
+
_dx = dx * d
|
| 122 |
+
_dy = dy * d
|
| 123 |
+
weight = 0
|
| 124 |
+
for s in range(num_samples):
|
| 125 |
+
s = (s - sample_offset) * im_size
|
| 126 |
+
x_ = x + int(round(_dx * s))
|
| 127 |
+
y_ = y + int(round(_dy * s))
|
| 128 |
+
if x_ >= im_w or y_ >= im_h or x_ < 0 or y_ < 0:
|
| 129 |
+
continue
|
| 130 |
+
_w = depth[y_, x_]
|
| 131 |
+
weight += _w
|
| 132 |
+
_color += img[y_, x_] * _w
|
| 133 |
+
if weight == 0:
|
| 134 |
+
blured[y, x] = img[y, x]
|
| 135 |
+
else:
|
| 136 |
+
blured[y, x] = _color / np.array([weight, weight, weight], dtype=np.float32)
|
| 137 |
+
|
| 138 |
+
return blured
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def bokeh_blur(img, depth, num_samples=32, lightness_factor=10, depth_factor=2, use_cuda=False, focal_plane=None):
|
| 144 |
+
img = np.ascontiguousarray(img)
|
| 145 |
+
|
| 146 |
+
if depth is not None:
|
| 147 |
+
depth = depth.astype(np.float32)
|
| 148 |
+
if focal_plane is not None:
|
| 149 |
+
depth = depth.max() - np.abs(depth - focal_plane)
|
| 150 |
+
if depth_factor != 1:
|
| 151 |
+
depth = np.power(depth, depth_factor)
|
| 152 |
+
depth = depth - depth.min()
|
| 153 |
+
depth = depth.astype(np.float32) / depth.max()
|
| 154 |
+
depth = 1 - depth
|
| 155 |
+
|
| 156 |
+
img = img.astype(np.float32) / 255
|
| 157 |
+
img_hightlighted = np.power(img, lightness_factor)
|
| 158 |
+
|
| 159 |
+
# img =
|
| 160 |
+
im_h, im_w = img.shape[:2]
|
| 161 |
+
PI = math.pi
|
| 162 |
+
|
| 163 |
+
_scale = 0.0005
|
| 164 |
+
depth = depth * _scale
|
| 165 |
+
|
| 166 |
+
if use_cuda:
|
| 167 |
+
img_hightlighted = np2flatten_tensor(img_hightlighted, True)
|
| 168 |
+
depth = np2flatten_tensor(depth, True)
|
| 169 |
+
vertical_blured = bokeh_filter_cupy(img_hightlighted, depth, 0, 1, im_h, im_w, num_samples)
|
| 170 |
+
diag_blured = bokeh_filter_cupy(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), im_h, im_w, num_samples)
|
| 171 |
+
rhom_blur = bokeh_filter_cupy(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), im_h, im_w, num_samples)
|
| 172 |
+
blured = (diag_blured + rhom_blur) / 2
|
| 173 |
+
blured = ftensor2img(blured, im_h, im_w)
|
| 174 |
+
else:
|
| 175 |
+
vertical_blured = bokeh_filter(img_hightlighted, depth, 0, 1, num_samples)
|
| 176 |
+
diag_blured = bokeh_filter(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), num_samples)
|
| 177 |
+
rhom_blur = bokeh_filter(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), num_samples)
|
| 178 |
+
blured = (diag_blured + rhom_blur) / 2
|
| 179 |
+
blured = np.power(blured, 1 / lightness_factor)
|
| 180 |
+
blured = (blured * 255).astype(np.uint8)
|
| 181 |
+
|
| 182 |
+
return blured
|
utils/env_utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import platform
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import torch.multiprocessing as mp
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_multi_processing(
|
| 9 |
+
mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
|
| 10 |
+
) -> None:
|
| 11 |
+
"""Set multi-processing related environment.
|
| 12 |
+
|
| 13 |
+
This function is refered from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
mp_start_method (str): Set the method which should be used to start
|
| 17 |
+
child processes. Defaults to 'fork'.
|
| 18 |
+
opencv_num_threads (int): Number of threads for opencv.
|
| 19 |
+
Defaults to 0.
|
| 20 |
+
distributed (bool): True if distributed environment.
|
| 21 |
+
Defaults to False.
|
| 22 |
+
""" # noqa
|
| 23 |
+
# set multi-process start method as `fork` to speed up the training
|
| 24 |
+
if platform.system() != "Windows":
|
| 25 |
+
current_method = mp.get_start_method(allow_none=True)
|
| 26 |
+
if current_method is not None and current_method != mp_start_method:
|
| 27 |
+
warnings.warn(
|
| 28 |
+
f"Multi-processing start method `{mp_start_method}` is "
|
| 29 |
+
f"different from the previous setting `{current_method}`."
|
| 30 |
+
f"It will be force set to `{mp_start_method}`. You can "
|
| 31 |
+
"change this behavior by changing `mp_start_method` in "
|
| 32 |
+
"your config."
|
| 33 |
+
)
|
| 34 |
+
mp.set_start_method(mp_start_method, force=True)
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
import cv2
|
| 38 |
+
|
| 39 |
+
# disable opencv multithreading to avoid system being overloaded
|
| 40 |
+
cv2.setNumThreads(opencv_num_threads)
|
| 41 |
+
except ImportError:
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
# setup OMP threads
|
| 45 |
+
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
| 46 |
+
if "OMP_NUM_THREADS" not in os.environ and distributed:
|
| 47 |
+
omp_num_threads = 1
|
| 48 |
+
warnings.warn(
|
| 49 |
+
"Setting OMP_NUM_THREADS environment variable for each process"
|
| 50 |
+
f" to be {omp_num_threads} in default, to avoid your system "
|
| 51 |
+
"being overloaded, please further tune the variable for "
|
| 52 |
+
"optimal performance in your application as needed."
|
| 53 |
+
)
|
| 54 |
+
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
|
| 55 |
+
|
| 56 |
+
# # setup MKL threads
|
| 57 |
+
if "MKL_NUM_THREADS" not in os.environ and distributed:
|
| 58 |
+
mkl_num_threads = 1
|
| 59 |
+
warnings.warn(
|
| 60 |
+
"Setting MKL_NUM_THREADS environment variable for each process"
|
| 61 |
+
f" to be {mkl_num_threads} in default, to avoid your system "
|
| 62 |
+
"being overloaded, please further tune the variable for "
|
| 63 |
+
"optimal performance in your application as needed."
|
| 64 |
+
)
|
| 65 |
+
os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
|
utils/helper_math.h
ADDED
|
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
| 5 |
+
* with this source code for terms and conditions that govern your use of
|
| 6 |
+
* this software. Any use, reproduction, disclosure, or distribution of
|
| 7 |
+
* this software and related documentation outside the terms of the EULA
|
| 8 |
+
* is strictly prohibited.
|
| 9 |
+
*
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
/*
|
| 13 |
+
* This file implements common mathematical operations on vector types
|
| 14 |
+
* (float3, float4 etc.) since these are not provided as standard by CUDA.
|
| 15 |
+
*
|
| 16 |
+
* The syntax is modeled on the Cg standard library.
|
| 17 |
+
*
|
| 18 |
+
* This is part of the Helper library includes
|
| 19 |
+
*
|
| 20 |
+
* Thanks to Linh Hah for additions and fixes.
|
| 21 |
+
*/
|
| 22 |
+
|
| 23 |
+
#ifndef HELPER_MATH_H
|
| 24 |
+
#define HELPER_MATH_H
|
| 25 |
+
|
| 26 |
+
#include "cuda_runtime.h"
|
| 27 |
+
|
| 28 |
+
typedef unsigned int uint;
|
| 29 |
+
typedef unsigned short ushort;
|
| 30 |
+
|
| 31 |
+
#ifndef __CUDACC__
|
| 32 |
+
#include <math.h>
|
| 33 |
+
|
| 34 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 35 |
+
// host implementations of CUDA functions
|
| 36 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
inline float fminf(float a, float b)
|
| 39 |
+
{
|
| 40 |
+
return a < b ? a : b;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
inline float fmaxf(float a, float b)
|
| 44 |
+
{
|
| 45 |
+
return a > b ? a : b;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
inline int max(int a, int b)
|
| 49 |
+
{
|
| 50 |
+
return a > b ? a : b;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
inline int min(int a, int b)
|
| 54 |
+
{
|
| 55 |
+
return a < b ? a : b;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
inline float rsqrtf(float x)
|
| 59 |
+
{
|
| 60 |
+
return 1.0f / sqrtf(x);
|
| 61 |
+
}
|
| 62 |
+
#endif
|
| 63 |
+
|
| 64 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
// constructors
|
| 66 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
inline __host__ __device__ float2 make_float2(float s)
|
| 69 |
+
{
|
| 70 |
+
return make_float2(s, s);
|
| 71 |
+
}
|
| 72 |
+
inline __host__ __device__ float2 make_float2(float3 a)
|
| 73 |
+
{
|
| 74 |
+
return make_float2(a.x, a.y);
|
| 75 |
+
}
|
| 76 |
+
inline __host__ __device__ float2 make_float2(int2 a)
|
| 77 |
+
{
|
| 78 |
+
return make_float2(float(a.x), float(a.y));
|
| 79 |
+
}
|
| 80 |
+
inline __host__ __device__ float2 make_float2(uint2 a)
|
| 81 |
+
{
|
| 82 |
+
return make_float2(float(a.x), float(a.y));
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
inline __host__ __device__ int2 make_int2(int s)
|
| 86 |
+
{
|
| 87 |
+
return make_int2(s, s);
|
| 88 |
+
}
|
| 89 |
+
inline __host__ __device__ int2 make_int2(int3 a)
|
| 90 |
+
{
|
| 91 |
+
return make_int2(a.x, a.y);
|
| 92 |
+
}
|
| 93 |
+
inline __host__ __device__ int2 make_int2(uint2 a)
|
| 94 |
+
{
|
| 95 |
+
return make_int2(int(a.x), int(a.y));
|
| 96 |
+
}
|
| 97 |
+
inline __host__ __device__ int2 make_int2(float2 a)
|
| 98 |
+
{
|
| 99 |
+
return make_int2(int(a.x), int(a.y));
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
inline __host__ __device__ uint2 make_uint2(uint s)
|
| 103 |
+
{
|
| 104 |
+
return make_uint2(s, s);
|
| 105 |
+
}
|
| 106 |
+
inline __host__ __device__ uint2 make_uint2(uint3 a)
|
| 107 |
+
{
|
| 108 |
+
return make_uint2(a.x, a.y);
|
| 109 |
+
}
|
| 110 |
+
inline __host__ __device__ uint2 make_uint2(int2 a)
|
| 111 |
+
{
|
| 112 |
+
return make_uint2(uint(a.x), uint(a.y));
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
inline __host__ __device__ float3 make_float3(float s)
|
| 116 |
+
{
|
| 117 |
+
return make_float3(s, s, s);
|
| 118 |
+
}
|
| 119 |
+
inline __host__ __device__ float3 make_float3(float2 a)
|
| 120 |
+
{
|
| 121 |
+
return make_float3(a.x, a.y, 0.0f);
|
| 122 |
+
}
|
| 123 |
+
inline __host__ __device__ float3 make_float3(float2 a, float s)
|
| 124 |
+
{
|
| 125 |
+
return make_float3(a.x, a.y, s);
|
| 126 |
+
}
|
| 127 |
+
inline __host__ __device__ float3 make_float3(float4 a)
|
| 128 |
+
{
|
| 129 |
+
return make_float3(a.x, a.y, a.z);
|
| 130 |
+
}
|
| 131 |
+
inline __host__ __device__ float3 make_float3(int3 a)
|
| 132 |
+
{
|
| 133 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 134 |
+
}
|
| 135 |
+
inline __host__ __device__ float3 make_float3(uint3 a)
|
| 136 |
+
{
|
| 137 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
inline __host__ __device__ int3 make_int3(int s)
|
| 141 |
+
{
|
| 142 |
+
return make_int3(s, s, s);
|
| 143 |
+
}
|
| 144 |
+
inline __host__ __device__ int3 make_int3(int2 a)
|
| 145 |
+
{
|
| 146 |
+
return make_int3(a.x, a.y, 0);
|
| 147 |
+
}
|
| 148 |
+
inline __host__ __device__ int3 make_int3(int2 a, int s)
|
| 149 |
+
{
|
| 150 |
+
return make_int3(a.x, a.y, s);
|
| 151 |
+
}
|
| 152 |
+
inline __host__ __device__ int3 make_int3(uint3 a)
|
| 153 |
+
{
|
| 154 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 155 |
+
}
|
| 156 |
+
inline __host__ __device__ int3 make_int3(float3 a)
|
| 157 |
+
{
|
| 158 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
inline __host__ __device__ uint3 make_uint3(uint s)
|
| 162 |
+
{
|
| 163 |
+
return make_uint3(s, s, s);
|
| 164 |
+
}
|
| 165 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a)
|
| 166 |
+
{
|
| 167 |
+
return make_uint3(a.x, a.y, 0);
|
| 168 |
+
}
|
| 169 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
|
| 170 |
+
{
|
| 171 |
+
return make_uint3(a.x, a.y, s);
|
| 172 |
+
}
|
| 173 |
+
inline __host__ __device__ uint3 make_uint3(uint4 a)
|
| 174 |
+
{
|
| 175 |
+
return make_uint3(a.x, a.y, a.z);
|
| 176 |
+
}
|
| 177 |
+
inline __host__ __device__ uint3 make_uint3(int3 a)
|
| 178 |
+
{
|
| 179 |
+
return make_uint3(uint(a.x), uint(a.y), uint(a.z));
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
inline __host__ __device__ float4 make_float4(float s)
|
| 183 |
+
{
|
| 184 |
+
return make_float4(s, s, s, s);
|
| 185 |
+
}
|
| 186 |
+
inline __host__ __device__ float4 make_float4(float3 a)
|
| 187 |
+
{
|
| 188 |
+
return make_float4(a.x, a.y, a.z, 0.0f);
|
| 189 |
+
}
|
| 190 |
+
inline __host__ __device__ float4 make_float4(float3 a, float w)
|
| 191 |
+
{
|
| 192 |
+
return make_float4(a.x, a.y, a.z, w);
|
| 193 |
+
}
|
| 194 |
+
inline __host__ __device__ float4 make_float4(int4 a)
|
| 195 |
+
{
|
| 196 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 197 |
+
}
|
| 198 |
+
inline __host__ __device__ float4 make_float4(uint4 a)
|
| 199 |
+
{
|
| 200 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
inline __host__ __device__ int4 make_int4(int s)
|
| 204 |
+
{
|
| 205 |
+
return make_int4(s, s, s, s);
|
| 206 |
+
}
|
| 207 |
+
inline __host__ __device__ int4 make_int4(int3 a)
|
| 208 |
+
{
|
| 209 |
+
return make_int4(a.x, a.y, a.z, 0);
|
| 210 |
+
}
|
| 211 |
+
inline __host__ __device__ int4 make_int4(int3 a, int w)
|
| 212 |
+
{
|
| 213 |
+
return make_int4(a.x, a.y, a.z, w);
|
| 214 |
+
}
|
| 215 |
+
inline __host__ __device__ int4 make_int4(uint4 a)
|
| 216 |
+
{
|
| 217 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 218 |
+
}
|
| 219 |
+
inline __host__ __device__ int4 make_int4(float4 a)
|
| 220 |
+
{
|
| 221 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
inline __host__ __device__ uint4 make_uint4(uint s)
|
| 226 |
+
{
|
| 227 |
+
return make_uint4(s, s, s, s);
|
| 228 |
+
}
|
| 229 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a)
|
| 230 |
+
{
|
| 231 |
+
return make_uint4(a.x, a.y, a.z, 0);
|
| 232 |
+
}
|
| 233 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
|
| 234 |
+
{
|
| 235 |
+
return make_uint4(a.x, a.y, a.z, w);
|
| 236 |
+
}
|
| 237 |
+
inline __host__ __device__ uint4 make_uint4(int4 a)
|
| 238 |
+
{
|
| 239 |
+
return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 243 |
+
// negate
|
| 244 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 245 |
+
|
| 246 |
+
inline __host__ __device__ float2 operator-(float2 &a)
|
| 247 |
+
{
|
| 248 |
+
return make_float2(-a.x, -a.y);
|
| 249 |
+
}
|
| 250 |
+
inline __host__ __device__ int2 operator-(int2 &a)
|
| 251 |
+
{
|
| 252 |
+
return make_int2(-a.x, -a.y);
|
| 253 |
+
}
|
| 254 |
+
inline __host__ __device__ float3 operator-(float3 &a)
|
| 255 |
+
{
|
| 256 |
+
return make_float3(-a.x, -a.y, -a.z);
|
| 257 |
+
}
|
| 258 |
+
inline __host__ __device__ int3 operator-(int3 &a)
|
| 259 |
+
{
|
| 260 |
+
return make_int3(-a.x, -a.y, -a.z);
|
| 261 |
+
}
|
| 262 |
+
inline __host__ __device__ float4 operator-(float4 &a)
|
| 263 |
+
{
|
| 264 |
+
return make_float4(-a.x, -a.y, -a.z, -a.w);
|
| 265 |
+
}
|
| 266 |
+
inline __host__ __device__ int4 operator-(int4 &a)
|
| 267 |
+
{
|
| 268 |
+
return make_int4(-a.x, -a.y, -a.z, -a.w);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 272 |
+
// addition
|
| 273 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 274 |
+
|
| 275 |
+
inline __host__ __device__ float2 operator+(float2 a, float2 b)
|
| 276 |
+
{
|
| 277 |
+
return make_float2(a.x + b.x, a.y + b.y);
|
| 278 |
+
}
|
| 279 |
+
inline __host__ __device__ void operator+=(float2 &a, float2 b)
|
| 280 |
+
{
|
| 281 |
+
a.x += b.x;
|
| 282 |
+
a.y += b.y;
|
| 283 |
+
}
|
| 284 |
+
inline __host__ __device__ float2 operator+(float2 a, float b)
|
| 285 |
+
{
|
| 286 |
+
return make_float2(a.x + b, a.y + b);
|
| 287 |
+
}
|
| 288 |
+
inline __host__ __device__ float2 operator+(float b, float2 a)
|
| 289 |
+
{
|
| 290 |
+
return make_float2(a.x + b, a.y + b);
|
| 291 |
+
}
|
| 292 |
+
inline __host__ __device__ void operator+=(float2 &a, float b)
|
| 293 |
+
{
|
| 294 |
+
a.x += b;
|
| 295 |
+
a.y += b;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
inline __host__ __device__ int2 operator+(int2 a, int2 b)
|
| 299 |
+
{
|
| 300 |
+
return make_int2(a.x + b.x, a.y + b.y);
|
| 301 |
+
}
|
| 302 |
+
inline __host__ __device__ void operator+=(int2 &a, int2 b)
|
| 303 |
+
{
|
| 304 |
+
a.x += b.x;
|
| 305 |
+
a.y += b.y;
|
| 306 |
+
}
|
| 307 |
+
inline __host__ __device__ int2 operator+(int2 a, int b)
|
| 308 |
+
{
|
| 309 |
+
return make_int2(a.x + b, a.y + b);
|
| 310 |
+
}
|
| 311 |
+
inline __host__ __device__ int2 operator+(int b, int2 a)
|
| 312 |
+
{
|
| 313 |
+
return make_int2(a.x + b, a.y + b);
|
| 314 |
+
}
|
| 315 |
+
inline __host__ __device__ void operator+=(int2 &a, int b)
|
| 316 |
+
{
|
| 317 |
+
a.x += b;
|
| 318 |
+
a.y += b;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
|
| 322 |
+
{
|
| 323 |
+
return make_uint2(a.x + b.x, a.y + b.y);
|
| 324 |
+
}
|
| 325 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
|
| 326 |
+
{
|
| 327 |
+
a.x += b.x;
|
| 328 |
+
a.y += b.y;
|
| 329 |
+
}
|
| 330 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
|
| 331 |
+
{
|
| 332 |
+
return make_uint2(a.x + b, a.y + b);
|
| 333 |
+
}
|
| 334 |
+
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
|
| 335 |
+
{
|
| 336 |
+
return make_uint2(a.x + b, a.y + b);
|
| 337 |
+
}
|
| 338 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint b)
|
| 339 |
+
{
|
| 340 |
+
a.x += b;
|
| 341 |
+
a.y += b;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
inline __host__ __device__ float3 operator+(float3 a, float3 b)
|
| 346 |
+
{
|
| 347 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 348 |
+
}
|
| 349 |
+
inline __host__ __device__ void operator+=(float3 &a, float3 b)
|
| 350 |
+
{
|
| 351 |
+
a.x += b.x;
|
| 352 |
+
a.y += b.y;
|
| 353 |
+
a.z += b.z;
|
| 354 |
+
}
|
| 355 |
+
inline __host__ __device__ float3 operator+(float3 a, float b)
|
| 356 |
+
{
|
| 357 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 358 |
+
}
|
| 359 |
+
inline __host__ __device__ void operator+=(float3 &a, float b)
|
| 360 |
+
{
|
| 361 |
+
a.x += b;
|
| 362 |
+
a.y += b;
|
| 363 |
+
a.z += b;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
inline __host__ __device__ int3 operator+(int3 a, int3 b)
|
| 367 |
+
{
|
| 368 |
+
return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 369 |
+
}
|
| 370 |
+
inline __host__ __device__ void operator+=(int3 &a, int3 b)
|
| 371 |
+
{
|
| 372 |
+
a.x += b.x;
|
| 373 |
+
a.y += b.y;
|
| 374 |
+
a.z += b.z;
|
| 375 |
+
}
|
| 376 |
+
inline __host__ __device__ int3 operator+(int3 a, int b)
|
| 377 |
+
{
|
| 378 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 379 |
+
}
|
| 380 |
+
inline __host__ __device__ void operator+=(int3 &a, int b)
|
| 381 |
+
{
|
| 382 |
+
a.x += b;
|
| 383 |
+
a.y += b;
|
| 384 |
+
a.z += b;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
|
| 388 |
+
{
|
| 389 |
+
return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 390 |
+
}
|
| 391 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
|
| 392 |
+
{
|
| 393 |
+
a.x += b.x;
|
| 394 |
+
a.y += b.y;
|
| 395 |
+
a.z += b.z;
|
| 396 |
+
}
|
| 397 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
|
| 398 |
+
{
|
| 399 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 400 |
+
}
|
| 401 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint b)
|
| 402 |
+
{
|
| 403 |
+
a.x += b;
|
| 404 |
+
a.y += b;
|
| 405 |
+
a.z += b;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
inline __host__ __device__ int3 operator+(int b, int3 a)
|
| 409 |
+
{
|
| 410 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 411 |
+
}
|
| 412 |
+
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
|
| 413 |
+
{
|
| 414 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 415 |
+
}
|
| 416 |
+
inline __host__ __device__ float3 operator+(float b, float3 a)
|
| 417 |
+
{
|
| 418 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
inline __host__ __device__ float4 operator+(float4 a, float4 b)
|
| 422 |
+
{
|
| 423 |
+
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 424 |
+
}
|
| 425 |
+
inline __host__ __device__ void operator+=(float4 &a, float4 b)
|
| 426 |
+
{
|
| 427 |
+
a.x += b.x;
|
| 428 |
+
a.y += b.y;
|
| 429 |
+
a.z += b.z;
|
| 430 |
+
a.w += b.w;
|
| 431 |
+
}
|
| 432 |
+
inline __host__ __device__ float4 operator+(float4 a, float b)
|
| 433 |
+
{
|
| 434 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 435 |
+
}
|
| 436 |
+
inline __host__ __device__ float4 operator+(float b, float4 a)
|
| 437 |
+
{
|
| 438 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 439 |
+
}
|
| 440 |
+
inline __host__ __device__ void operator+=(float4 &a, float b)
|
| 441 |
+
{
|
| 442 |
+
a.x += b;
|
| 443 |
+
a.y += b;
|
| 444 |
+
a.z += b;
|
| 445 |
+
a.w += b;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
inline __host__ __device__ int4 operator+(int4 a, int4 b)
|
| 449 |
+
{
|
| 450 |
+
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 451 |
+
}
|
| 452 |
+
inline __host__ __device__ void operator+=(int4 &a, int4 b)
|
| 453 |
+
{
|
| 454 |
+
a.x += b.x;
|
| 455 |
+
a.y += b.y;
|
| 456 |
+
a.z += b.z;
|
| 457 |
+
a.w += b.w;
|
| 458 |
+
}
|
| 459 |
+
inline __host__ __device__ int4 operator+(int4 a, int b)
|
| 460 |
+
{
|
| 461 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 462 |
+
}
|
| 463 |
+
inline __host__ __device__ int4 operator+(int b, int4 a)
|
| 464 |
+
{
|
| 465 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 466 |
+
}
|
| 467 |
+
inline __host__ __device__ void operator+=(int4 &a, int b)
|
| 468 |
+
{
|
| 469 |
+
a.x += b;
|
| 470 |
+
a.y += b;
|
| 471 |
+
a.z += b;
|
| 472 |
+
a.w += b;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
|
| 476 |
+
{
|
| 477 |
+
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 478 |
+
}
|
| 479 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
|
| 480 |
+
{
|
| 481 |
+
a.x += b.x;
|
| 482 |
+
a.y += b.y;
|
| 483 |
+
a.z += b.z;
|
| 484 |
+
a.w += b.w;
|
| 485 |
+
}
|
| 486 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
|
| 487 |
+
{
|
| 488 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 489 |
+
}
|
| 490 |
+
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
|
| 491 |
+
{
|
| 492 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 493 |
+
}
|
| 494 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint b)
|
| 495 |
+
{
|
| 496 |
+
a.x += b;
|
| 497 |
+
a.y += b;
|
| 498 |
+
a.z += b;
|
| 499 |
+
a.w += b;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 503 |
+
// subtract
|
| 504 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 505 |
+
|
| 506 |
+
inline __host__ __device__ float2 operator-(float2 a, float2 b)
|
| 507 |
+
{
|
| 508 |
+
return make_float2(a.x - b.x, a.y - b.y);
|
| 509 |
+
}
|
| 510 |
+
inline __host__ __device__ void operator-=(float2 &a, float2 b)
|
| 511 |
+
{
|
| 512 |
+
a.x -= b.x;
|
| 513 |
+
a.y -= b.y;
|
| 514 |
+
}
|
| 515 |
+
inline __host__ __device__ float2 operator-(float2 a, float b)
|
| 516 |
+
{
|
| 517 |
+
return make_float2(a.x - b, a.y - b);
|
| 518 |
+
}
|
| 519 |
+
inline __host__ __device__ float2 operator-(float b, float2 a)
|
| 520 |
+
{
|
| 521 |
+
return make_float2(b - a.x, b - a.y);
|
| 522 |
+
}
|
| 523 |
+
inline __host__ __device__ void operator-=(float2 &a, float b)
|
| 524 |
+
{
|
| 525 |
+
a.x -= b;
|
| 526 |
+
a.y -= b;
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
inline __host__ __device__ int2 operator-(int2 a, int2 b)
|
| 530 |
+
{
|
| 531 |
+
return make_int2(a.x - b.x, a.y - b.y);
|
| 532 |
+
}
|
| 533 |
+
inline __host__ __device__ void operator-=(int2 &a, int2 b)
|
| 534 |
+
{
|
| 535 |
+
a.x -= b.x;
|
| 536 |
+
a.y -= b.y;
|
| 537 |
+
}
|
| 538 |
+
inline __host__ __device__ int2 operator-(int2 a, int b)
|
| 539 |
+
{
|
| 540 |
+
return make_int2(a.x - b, a.y - b);
|
| 541 |
+
}
|
| 542 |
+
inline __host__ __device__ int2 operator-(int b, int2 a)
|
| 543 |
+
{
|
| 544 |
+
return make_int2(b - a.x, b - a.y);
|
| 545 |
+
}
|
| 546 |
+
inline __host__ __device__ void operator-=(int2 &a, int b)
|
| 547 |
+
{
|
| 548 |
+
a.x -= b;
|
| 549 |
+
a.y -= b;
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
|
| 553 |
+
{
|
| 554 |
+
return make_uint2(a.x - b.x, a.y - b.y);
|
| 555 |
+
}
|
| 556 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
|
| 557 |
+
{
|
| 558 |
+
a.x -= b.x;
|
| 559 |
+
a.y -= b.y;
|
| 560 |
+
}
|
| 561 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
|
| 562 |
+
{
|
| 563 |
+
return make_uint2(a.x - b, a.y - b);
|
| 564 |
+
}
|
| 565 |
+
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
|
| 566 |
+
{
|
| 567 |
+
return make_uint2(b - a.x, b - a.y);
|
| 568 |
+
}
|
| 569 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint b)
|
| 570 |
+
{
|
| 571 |
+
a.x -= b;
|
| 572 |
+
a.y -= b;
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
inline __host__ __device__ float3 operator-(float3 a, float3 b)
|
| 576 |
+
{
|
| 577 |
+
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 578 |
+
}
|
| 579 |
+
inline __host__ __device__ void operator-=(float3 &a, float3 b)
|
| 580 |
+
{
|
| 581 |
+
a.x -= b.x;
|
| 582 |
+
a.y -= b.y;
|
| 583 |
+
a.z -= b.z;
|
| 584 |
+
}
|
| 585 |
+
inline __host__ __device__ float3 operator-(float3 a, float b)
|
| 586 |
+
{
|
| 587 |
+
return make_float3(a.x - b, a.y - b, a.z - b);
|
| 588 |
+
}
|
| 589 |
+
inline __host__ __device__ float3 operator-(float b, float3 a)
|
| 590 |
+
{
|
| 591 |
+
return make_float3(b - a.x, b - a.y, b - a.z);
|
| 592 |
+
}
|
| 593 |
+
inline __host__ __device__ void operator-=(float3 &a, float b)
|
| 594 |
+
{
|
| 595 |
+
a.x -= b;
|
| 596 |
+
a.y -= b;
|
| 597 |
+
a.z -= b;
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
inline __host__ __device__ int3 operator-(int3 a, int3 b)
|
| 601 |
+
{
|
| 602 |
+
return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 603 |
+
}
|
| 604 |
+
inline __host__ __device__ void operator-=(int3 &a, int3 b)
|
| 605 |
+
{
|
| 606 |
+
a.x -= b.x;
|
| 607 |
+
a.y -= b.y;
|
| 608 |
+
a.z -= b.z;
|
| 609 |
+
}
|
| 610 |
+
inline __host__ __device__ int3 operator-(int3 a, int b)
|
| 611 |
+
{
|
| 612 |
+
return make_int3(a.x - b, a.y - b, a.z - b);
|
| 613 |
+
}
|
| 614 |
+
inline __host__ __device__ int3 operator-(int b, int3 a)
|
| 615 |
+
{
|
| 616 |
+
return make_int3(b - a.x, b - a.y, b - a.z);
|
| 617 |
+
}
|
| 618 |
+
inline __host__ __device__ void operator-=(int3 &a, int b)
|
| 619 |
+
{
|
| 620 |
+
a.x -= b;
|
| 621 |
+
a.y -= b;
|
| 622 |
+
a.z -= b;
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
|
| 626 |
+
{
|
| 627 |
+
return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 628 |
+
}
|
| 629 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
|
| 630 |
+
{
|
| 631 |
+
a.x -= b.x;
|
| 632 |
+
a.y -= b.y;
|
| 633 |
+
a.z -= b.z;
|
| 634 |
+
}
|
| 635 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
|
| 636 |
+
{
|
| 637 |
+
return make_uint3(a.x - b, a.y - b, a.z - b);
|
| 638 |
+
}
|
| 639 |
+
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
|
| 640 |
+
{
|
| 641 |
+
return make_uint3(b - a.x, b - a.y, b - a.z);
|
| 642 |
+
}
|
| 643 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint b)
|
| 644 |
+
{
|
| 645 |
+
a.x -= b;
|
| 646 |
+
a.y -= b;
|
| 647 |
+
a.z -= b;
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
inline __host__ __device__ float4 operator-(float4 a, float4 b)
|
| 651 |
+
{
|
| 652 |
+
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 653 |
+
}
|
| 654 |
+
inline __host__ __device__ void operator-=(float4 &a, float4 b)
|
| 655 |
+
{
|
| 656 |
+
a.x -= b.x;
|
| 657 |
+
a.y -= b.y;
|
| 658 |
+
a.z -= b.z;
|
| 659 |
+
a.w -= b.w;
|
| 660 |
+
}
|
| 661 |
+
inline __host__ __device__ float4 operator-(float4 a, float b)
|
| 662 |
+
{
|
| 663 |
+
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 664 |
+
}
|
| 665 |
+
inline __host__ __device__ void operator-=(float4 &a, float b)
|
| 666 |
+
{
|
| 667 |
+
a.x -= b;
|
| 668 |
+
a.y -= b;
|
| 669 |
+
a.z -= b;
|
| 670 |
+
a.w -= b;
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
inline __host__ __device__ int4 operator-(int4 a, int4 b)
|
| 674 |
+
{
|
| 675 |
+
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 676 |
+
}
|
| 677 |
+
inline __host__ __device__ void operator-=(int4 &a, int4 b)
|
| 678 |
+
{
|
| 679 |
+
a.x -= b.x;
|
| 680 |
+
a.y -= b.y;
|
| 681 |
+
a.z -= b.z;
|
| 682 |
+
a.w -= b.w;
|
| 683 |
+
}
|
| 684 |
+
inline __host__ __device__ int4 operator-(int4 a, int b)
|
| 685 |
+
{
|
| 686 |
+
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 687 |
+
}
|
| 688 |
+
inline __host__ __device__ int4 operator-(int b, int4 a)
|
| 689 |
+
{
|
| 690 |
+
return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 691 |
+
}
|
| 692 |
+
inline __host__ __device__ void operator-=(int4 &a, int b)
|
| 693 |
+
{
|
| 694 |
+
a.x -= b;
|
| 695 |
+
a.y -= b;
|
| 696 |
+
a.z -= b;
|
| 697 |
+
a.w -= b;
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
|
| 701 |
+
{
|
| 702 |
+
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 703 |
+
}
|
| 704 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
|
| 705 |
+
{
|
| 706 |
+
a.x -= b.x;
|
| 707 |
+
a.y -= b.y;
|
| 708 |
+
a.z -= b.z;
|
| 709 |
+
a.w -= b.w;
|
| 710 |
+
}
|
| 711 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
|
| 712 |
+
{
|
| 713 |
+
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 714 |
+
}
|
| 715 |
+
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
|
| 716 |
+
{
|
| 717 |
+
return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 718 |
+
}
|
| 719 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint b)
|
| 720 |
+
{
|
| 721 |
+
a.x -= b;
|
| 722 |
+
a.y -= b;
|
| 723 |
+
a.z -= b;
|
| 724 |
+
a.w -= b;
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 728 |
+
// multiply
|
| 729 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 730 |
+
|
| 731 |
+
inline __host__ __device__ float2 operator*(float2 a, float2 b)
|
| 732 |
+
{
|
| 733 |
+
return make_float2(a.x * b.x, a.y * b.y);
|
| 734 |
+
}
|
| 735 |
+
inline __host__ __device__ void operator*=(float2 &a, float2 b)
|
| 736 |
+
{
|
| 737 |
+
a.x *= b.x;
|
| 738 |
+
a.y *= b.y;
|
| 739 |
+
}
|
| 740 |
+
inline __host__ __device__ float2 operator*(float2 a, float b)
|
| 741 |
+
{
|
| 742 |
+
return make_float2(a.x * b, a.y * b);
|
| 743 |
+
}
|
| 744 |
+
inline __host__ __device__ float2 operator*(float b, float2 a)
|
| 745 |
+
{
|
| 746 |
+
return make_float2(b * a.x, b * a.y);
|
| 747 |
+
}
|
| 748 |
+
inline __host__ __device__ void operator*=(float2 &a, float b)
|
| 749 |
+
{
|
| 750 |
+
a.x *= b;
|
| 751 |
+
a.y *= b;
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
inline __host__ __device__ int2 operator*(int2 a, int2 b)
|
| 755 |
+
{
|
| 756 |
+
return make_int2(a.x * b.x, a.y * b.y);
|
| 757 |
+
}
|
| 758 |
+
inline __host__ __device__ void operator*=(int2 &a, int2 b)
|
| 759 |
+
{
|
| 760 |
+
a.x *= b.x;
|
| 761 |
+
a.y *= b.y;
|
| 762 |
+
}
|
| 763 |
+
inline __host__ __device__ int2 operator*(int2 a, int b)
|
| 764 |
+
{
|
| 765 |
+
return make_int2(a.x * b, a.y * b);
|
| 766 |
+
}
|
| 767 |
+
inline __host__ __device__ int2 operator*(int b, int2 a)
|
| 768 |
+
{
|
| 769 |
+
return make_int2(b * a.x, b * a.y);
|
| 770 |
+
}
|
| 771 |
+
inline __host__ __device__ void operator*=(int2 &a, int b)
|
| 772 |
+
{
|
| 773 |
+
a.x *= b;
|
| 774 |
+
a.y *= b;
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
|
| 778 |
+
{
|
| 779 |
+
return make_uint2(a.x * b.x, a.y * b.y);
|
| 780 |
+
}
|
| 781 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
|
| 782 |
+
{
|
| 783 |
+
a.x *= b.x;
|
| 784 |
+
a.y *= b.y;
|
| 785 |
+
}
|
| 786 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
|
| 787 |
+
{
|
| 788 |
+
return make_uint2(a.x * b, a.y * b);
|
| 789 |
+
}
|
| 790 |
+
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
|
| 791 |
+
{
|
| 792 |
+
return make_uint2(b * a.x, b * a.y);
|
| 793 |
+
}
|
| 794 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint b)
|
| 795 |
+
{
|
| 796 |
+
a.x *= b;
|
| 797 |
+
a.y *= b;
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
inline __host__ __device__ float3 operator*(float3 a, float3 b)
|
| 801 |
+
{
|
| 802 |
+
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 803 |
+
}
|
| 804 |
+
inline __host__ __device__ void operator*=(float3 &a, float3 b)
|
| 805 |
+
{
|
| 806 |
+
a.x *= b.x;
|
| 807 |
+
a.y *= b.y;
|
| 808 |
+
a.z *= b.z;
|
| 809 |
+
}
|
| 810 |
+
inline __host__ __device__ float3 operator*(float3 a, float b)
|
| 811 |
+
{
|
| 812 |
+
return make_float3(a.x * b, a.y * b, a.z * b);
|
| 813 |
+
}
|
| 814 |
+
inline __host__ __device__ float3 operator*(float b, float3 a)
|
| 815 |
+
{
|
| 816 |
+
return make_float3(b * a.x, b * a.y, b * a.z);
|
| 817 |
+
}
|
| 818 |
+
inline __host__ __device__ void operator*=(float3 &a, float b)
|
| 819 |
+
{
|
| 820 |
+
a.x *= b;
|
| 821 |
+
a.y *= b;
|
| 822 |
+
a.z *= b;
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
inline __host__ __device__ int3 operator*(int3 a, int3 b)
|
| 826 |
+
{
|
| 827 |
+
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 828 |
+
}
|
| 829 |
+
inline __host__ __device__ void operator*=(int3 &a, int3 b)
|
| 830 |
+
{
|
| 831 |
+
a.x *= b.x;
|
| 832 |
+
a.y *= b.y;
|
| 833 |
+
a.z *= b.z;
|
| 834 |
+
}
|
| 835 |
+
inline __host__ __device__ int3 operator*(int3 a, int b)
|
| 836 |
+
{
|
| 837 |
+
return make_int3(a.x * b, a.y * b, a.z * b);
|
| 838 |
+
}
|
| 839 |
+
inline __host__ __device__ int3 operator*(int b, int3 a)
|
| 840 |
+
{
|
| 841 |
+
return make_int3(b * a.x, b * a.y, b * a.z);
|
| 842 |
+
}
|
| 843 |
+
inline __host__ __device__ void operator*=(int3 &a, int b)
|
| 844 |
+
{
|
| 845 |
+
a.x *= b;
|
| 846 |
+
a.y *= b;
|
| 847 |
+
a.z *= b;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
|
| 851 |
+
{
|
| 852 |
+
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 853 |
+
}
|
| 854 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
|
| 855 |
+
{
|
| 856 |
+
a.x *= b.x;
|
| 857 |
+
a.y *= b.y;
|
| 858 |
+
a.z *= b.z;
|
| 859 |
+
}
|
| 860 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
|
| 861 |
+
{
|
| 862 |
+
return make_uint3(a.x * b, a.y * b, a.z * b);
|
| 863 |
+
}
|
| 864 |
+
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
|
| 865 |
+
{
|
| 866 |
+
return make_uint3(b * a.x, b * a.y, b * a.z);
|
| 867 |
+
}
|
| 868 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint b)
|
| 869 |
+
{
|
| 870 |
+
a.x *= b;
|
| 871 |
+
a.y *= b;
|
| 872 |
+
a.z *= b;
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
inline __host__ __device__ float4 operator*(float4 a, float4 b)
|
| 876 |
+
{
|
| 877 |
+
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 878 |
+
}
|
| 879 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
| 880 |
+
{
|
| 881 |
+
a.x *= b.x;
|
| 882 |
+
a.y *= b.y;
|
| 883 |
+
a.z *= b.z;
|
| 884 |
+
a.w *= b.w;
|
| 885 |
+
}
|
| 886 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
| 887 |
+
{
|
| 888 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 889 |
+
}
|
| 890 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
| 891 |
+
{
|
| 892 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 893 |
+
}
|
| 894 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
| 895 |
+
{
|
| 896 |
+
a.x *= b;
|
| 897 |
+
a.y *= b;
|
| 898 |
+
a.z *= b;
|
| 899 |
+
a.w *= b;
|
| 900 |
+
}
|
| 901 |
+
|
| 902 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
| 903 |
+
{
|
| 904 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 905 |
+
}
|
| 906 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
| 907 |
+
{
|
| 908 |
+
a.x *= b.x;
|
| 909 |
+
a.y *= b.y;
|
| 910 |
+
a.z *= b.z;
|
| 911 |
+
a.w *= b.w;
|
| 912 |
+
}
|
| 913 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
| 914 |
+
{
|
| 915 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 916 |
+
}
|
| 917 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
| 918 |
+
{
|
| 919 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 920 |
+
}
|
| 921 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
| 922 |
+
{
|
| 923 |
+
a.x *= b;
|
| 924 |
+
a.y *= b;
|
| 925 |
+
a.z *= b;
|
| 926 |
+
a.w *= b;
|
| 927 |
+
}
|
| 928 |
+
|
| 929 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
| 930 |
+
{
|
| 931 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 932 |
+
}
|
| 933 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
| 934 |
+
{
|
| 935 |
+
a.x *= b.x;
|
| 936 |
+
a.y *= b.y;
|
| 937 |
+
a.z *= b.z;
|
| 938 |
+
a.w *= b.w;
|
| 939 |
+
}
|
| 940 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
| 941 |
+
{
|
| 942 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 943 |
+
}
|
| 944 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
| 945 |
+
{
|
| 946 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 947 |
+
}
|
| 948 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
| 949 |
+
{
|
| 950 |
+
a.x *= b;
|
| 951 |
+
a.y *= b;
|
| 952 |
+
a.z *= b;
|
| 953 |
+
a.w *= b;
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 957 |
+
// divide
|
| 958 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 959 |
+
|
| 960 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
| 961 |
+
{
|
| 962 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
| 963 |
+
}
|
| 964 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
| 965 |
+
{
|
| 966 |
+
a.x /= b.x;
|
| 967 |
+
a.y /= b.y;
|
| 968 |
+
}
|
| 969 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
| 970 |
+
{
|
| 971 |
+
return make_float2(a.x / b, a.y / b);
|
| 972 |
+
}
|
| 973 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
| 974 |
+
{
|
| 975 |
+
a.x /= b;
|
| 976 |
+
a.y /= b;
|
| 977 |
+
}
|
| 978 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
| 979 |
+
{
|
| 980 |
+
return make_float2(b / a.x, b / a.y);
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
| 984 |
+
{
|
| 985 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
| 986 |
+
}
|
| 987 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
| 988 |
+
{
|
| 989 |
+
a.x /= b.x;
|
| 990 |
+
a.y /= b.y;
|
| 991 |
+
a.z /= b.z;
|
| 992 |
+
}
|
| 993 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
| 994 |
+
{
|
| 995 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
| 996 |
+
}
|
| 997 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
| 998 |
+
{
|
| 999 |
+
a.x /= b;
|
| 1000 |
+
a.y /= b;
|
| 1001 |
+
a.z /= b;
|
| 1002 |
+
}
|
| 1003 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
| 1004 |
+
{
|
| 1005 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
| 1009 |
+
{
|
| 1010 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
| 1011 |
+
}
|
| 1012 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
| 1013 |
+
{
|
| 1014 |
+
a.x /= b.x;
|
| 1015 |
+
a.y /= b.y;
|
| 1016 |
+
a.z /= b.z;
|
| 1017 |
+
a.w /= b.w;
|
| 1018 |
+
}
|
| 1019 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
| 1020 |
+
{
|
| 1021 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
| 1022 |
+
}
|
| 1023 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
| 1024 |
+
{
|
| 1025 |
+
a.x /= b;
|
| 1026 |
+
a.y /= b;
|
| 1027 |
+
a.z /= b;
|
| 1028 |
+
a.w /= b;
|
| 1029 |
+
}
|
| 1030 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
| 1031 |
+
{
|
| 1032 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
| 1033 |
+
}
|
| 1034 |
+
|
| 1035 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1036 |
+
// min
|
| 1037 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1038 |
+
|
| 1039 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
| 1040 |
+
{
|
| 1041 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
| 1042 |
+
}
|
| 1043 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
| 1044 |
+
{
|
| 1045 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
| 1046 |
+
}
|
| 1047 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
| 1048 |
+
{
|
| 1049 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
| 1050 |
+
}
|
| 1051 |
+
|
| 1052 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
| 1053 |
+
{
|
| 1054 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
| 1055 |
+
}
|
| 1056 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
| 1057 |
+
{
|
| 1058 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1059 |
+
}
|
| 1060 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
| 1061 |
+
{
|
| 1062 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
| 1066 |
+
{
|
| 1067 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
| 1068 |
+
}
|
| 1069 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
| 1070 |
+
{
|
| 1071 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1072 |
+
}
|
| 1073 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
| 1074 |
+
{
|
| 1075 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1076 |
+
}
|
| 1077 |
+
|
| 1078 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1079 |
+
// max
|
| 1080 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1081 |
+
|
| 1082 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
| 1083 |
+
{
|
| 1084 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
| 1085 |
+
}
|
| 1086 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
| 1087 |
+
{
|
| 1088 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
| 1089 |
+
}
|
| 1090 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
| 1091 |
+
{
|
| 1092 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
| 1093 |
+
}
|
| 1094 |
+
|
| 1095 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
| 1096 |
+
{
|
| 1097 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
| 1098 |
+
}
|
| 1099 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
| 1100 |
+
{
|
| 1101 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1102 |
+
}
|
| 1103 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
| 1104 |
+
{
|
| 1105 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1106 |
+
}
|
| 1107 |
+
|
| 1108 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
| 1109 |
+
{
|
| 1110 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
| 1111 |
+
}
|
| 1112 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
| 1113 |
+
{
|
| 1114 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1115 |
+
}
|
| 1116 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
| 1117 |
+
{
|
| 1118 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1119 |
+
}
|
| 1120 |
+
|
| 1121 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1122 |
+
// lerp
|
| 1123 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
| 1124 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1125 |
+
|
| 1126 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
| 1127 |
+
{
|
| 1128 |
+
return a + t*(b-a);
|
| 1129 |
+
}
|
| 1130 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
| 1131 |
+
{
|
| 1132 |
+
return a + t*(b-a);
|
| 1133 |
+
}
|
| 1134 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
| 1135 |
+
{
|
| 1136 |
+
return a + t*(b-a);
|
| 1137 |
+
}
|
| 1138 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
| 1139 |
+
{
|
| 1140 |
+
return a + t*(b-a);
|
| 1141 |
+
}
|
| 1142 |
+
|
| 1143 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1144 |
+
// clamp
|
| 1145 |
+
// - clamp the value v to be in the range [a, b]
|
| 1146 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1147 |
+
|
| 1148 |
+
inline __device__ __host__ float clamp(float f, float a, float b)
|
| 1149 |
+
{
|
| 1150 |
+
return fmaxf(a, fminf(f, b));
|
| 1151 |
+
}
|
| 1152 |
+
inline __device__ __host__ int clamp(int f, int a, int b)
|
| 1153 |
+
{
|
| 1154 |
+
return max(a, min(f, b));
|
| 1155 |
+
}
|
| 1156 |
+
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
|
| 1157 |
+
{
|
| 1158 |
+
return max(a, min(f, b));
|
| 1159 |
+
}
|
| 1160 |
+
|
| 1161 |
+
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
|
| 1162 |
+
{
|
| 1163 |
+
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1164 |
+
}
|
| 1165 |
+
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
|
| 1166 |
+
{
|
| 1167 |
+
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1168 |
+
}
|
| 1169 |
+
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
|
| 1170 |
+
{
|
| 1171 |
+
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1172 |
+
}
|
| 1173 |
+
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
|
| 1174 |
+
{
|
| 1175 |
+
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1176 |
+
}
|
| 1177 |
+
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
|
| 1178 |
+
{
|
| 1179 |
+
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1180 |
+
}
|
| 1181 |
+
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
|
| 1182 |
+
{
|
| 1183 |
+
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1184 |
+
}
|
| 1185 |
+
|
| 1186 |
+
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
|
| 1187 |
+
{
|
| 1188 |
+
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1189 |
+
}
|
| 1190 |
+
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
|
| 1191 |
+
{
|
| 1192 |
+
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1193 |
+
}
|
| 1194 |
+
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
|
| 1195 |
+
{
|
| 1196 |
+
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1197 |
+
}
|
| 1198 |
+
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
|
| 1199 |
+
{
|
| 1200 |
+
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1201 |
+
}
|
| 1202 |
+
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
|
| 1203 |
+
{
|
| 1204 |
+
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1205 |
+
}
|
| 1206 |
+
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
|
| 1207 |
+
{
|
| 1208 |
+
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1209 |
+
}
|
| 1210 |
+
|
| 1211 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
|
| 1212 |
+
{
|
| 1213 |
+
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1214 |
+
}
|
| 1215 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
|
| 1216 |
+
{
|
| 1217 |
+
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1218 |
+
}
|
| 1219 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
|
| 1220 |
+
{
|
| 1221 |
+
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1222 |
+
}
|
| 1223 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
|
| 1224 |
+
{
|
| 1225 |
+
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1226 |
+
}
|
| 1227 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
|
| 1228 |
+
{
|
| 1229 |
+
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1230 |
+
}
|
| 1231 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
|
| 1232 |
+
{
|
| 1233 |
+
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1234 |
+
}
|
| 1235 |
+
|
| 1236 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1237 |
+
// dot product
|
| 1238 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1239 |
+
|
| 1240 |
+
inline __host__ __device__ float dot(float2 a, float2 b)
|
| 1241 |
+
{
|
| 1242 |
+
return a.x * b.x + a.y * b.y;
|
| 1243 |
+
}
|
| 1244 |
+
inline __host__ __device__ float dot(float3 a, float3 b)
|
| 1245 |
+
{
|
| 1246 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1247 |
+
}
|
| 1248 |
+
inline __host__ __device__ float dot(float4 a, float4 b)
|
| 1249 |
+
{
|
| 1250 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1251 |
+
}
|
| 1252 |
+
|
| 1253 |
+
inline __host__ __device__ int dot(int2 a, int2 b)
|
| 1254 |
+
{
|
| 1255 |
+
return a.x * b.x + a.y * b.y;
|
| 1256 |
+
}
|
| 1257 |
+
inline __host__ __device__ int dot(int3 a, int3 b)
|
| 1258 |
+
{
|
| 1259 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1260 |
+
}
|
| 1261 |
+
inline __host__ __device__ int dot(int4 a, int4 b)
|
| 1262 |
+
{
|
| 1263 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1264 |
+
}
|
| 1265 |
+
|
| 1266 |
+
inline __host__ __device__ uint dot(uint2 a, uint2 b)
|
| 1267 |
+
{
|
| 1268 |
+
return a.x * b.x + a.y * b.y;
|
| 1269 |
+
}
|
| 1270 |
+
inline __host__ __device__ uint dot(uint3 a, uint3 b)
|
| 1271 |
+
{
|
| 1272 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1273 |
+
}
|
| 1274 |
+
inline __host__ __device__ uint dot(uint4 a, uint4 b)
|
| 1275 |
+
{
|
| 1276 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1280 |
+
// length
|
| 1281 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1282 |
+
|
| 1283 |
+
inline __host__ __device__ float length(float2 v)
|
| 1284 |
+
{
|
| 1285 |
+
return sqrtf(dot(v, v));
|
| 1286 |
+
}
|
| 1287 |
+
inline __host__ __device__ float length(float3 v)
|
| 1288 |
+
{
|
| 1289 |
+
return sqrtf(dot(v, v));
|
| 1290 |
+
}
|
| 1291 |
+
inline __host__ __device__ float length(float4 v)
|
| 1292 |
+
{
|
| 1293 |
+
return sqrtf(dot(v, v));
|
| 1294 |
+
}
|
| 1295 |
+
|
| 1296 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1297 |
+
// normalize
|
| 1298 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1299 |
+
|
| 1300 |
+
inline __host__ __device__ float2 normalize(float2 v)
|
| 1301 |
+
{
|
| 1302 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1303 |
+
return v * invLen;
|
| 1304 |
+
}
|
| 1305 |
+
inline __host__ __device__ float3 normalize(float3 v)
|
| 1306 |
+
{
|
| 1307 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1308 |
+
return v * invLen;
|
| 1309 |
+
}
|
| 1310 |
+
inline __host__ __device__ float4 normalize(float4 v)
|
| 1311 |
+
{
|
| 1312 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1313 |
+
return v * invLen;
|
| 1314 |
+
}
|
| 1315 |
+
|
| 1316 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1317 |
+
// floor
|
| 1318 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1319 |
+
|
| 1320 |
+
inline __host__ __device__ float2 floorf(float2 v)
|
| 1321 |
+
{
|
| 1322 |
+
return make_float2(floorf(v.x), floorf(v.y));
|
| 1323 |
+
}
|
| 1324 |
+
inline __host__ __device__ float3 floorf(float3 v)
|
| 1325 |
+
{
|
| 1326 |
+
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
|
| 1327 |
+
}
|
| 1328 |
+
inline __host__ __device__ float4 floorf(float4 v)
|
| 1329 |
+
{
|
| 1330 |
+
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
|
| 1331 |
+
}
|
| 1332 |
+
|
| 1333 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1334 |
+
// frac - returns the fractional portion of a scalar or each vector component
|
| 1335 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1336 |
+
|
| 1337 |
+
inline __host__ __device__ float fracf(float v)
|
| 1338 |
+
{
|
| 1339 |
+
return v - floorf(v);
|
| 1340 |
+
}
|
| 1341 |
+
inline __host__ __device__ float2 fracf(float2 v)
|
| 1342 |
+
{
|
| 1343 |
+
return make_float2(fracf(v.x), fracf(v.y));
|
| 1344 |
+
}
|
| 1345 |
+
inline __host__ __device__ float3 fracf(float3 v)
|
| 1346 |
+
{
|
| 1347 |
+
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
|
| 1348 |
+
}
|
| 1349 |
+
inline __host__ __device__ float4 fracf(float4 v)
|
| 1350 |
+
{
|
| 1351 |
+
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
|
| 1352 |
+
}
|
| 1353 |
+
|
| 1354 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1355 |
+
// fmod
|
| 1356 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1357 |
+
|
| 1358 |
+
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
|
| 1359 |
+
{
|
| 1360 |
+
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
|
| 1361 |
+
}
|
| 1362 |
+
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
|
| 1363 |
+
{
|
| 1364 |
+
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
|
| 1365 |
+
}
|
| 1366 |
+
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
|
| 1367 |
+
{
|
| 1368 |
+
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1372 |
+
// absolute value
|
| 1373 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1374 |
+
|
| 1375 |
+
inline __host__ __device__ float2 fabs(float2 v)
|
| 1376 |
+
{
|
| 1377 |
+
return make_float2(fabs(v.x), fabs(v.y));
|
| 1378 |
+
}
|
| 1379 |
+
inline __host__ __device__ float3 fabs(float3 v)
|
| 1380 |
+
{
|
| 1381 |
+
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
|
| 1382 |
+
}
|
| 1383 |
+
inline __host__ __device__ float4 fabs(float4 v)
|
| 1384 |
+
{
|
| 1385 |
+
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
|
| 1386 |
+
}
|
| 1387 |
+
|
| 1388 |
+
inline __host__ __device__ int2 abs(int2 v)
|
| 1389 |
+
{
|
| 1390 |
+
return make_int2(abs(v.x), abs(v.y));
|
| 1391 |
+
}
|
| 1392 |
+
inline __host__ __device__ int3 abs(int3 v)
|
| 1393 |
+
{
|
| 1394 |
+
return make_int3(abs(v.x), abs(v.y), abs(v.z));
|
| 1395 |
+
}
|
| 1396 |
+
inline __host__ __device__ int4 abs(int4 v)
|
| 1397 |
+
{
|
| 1398 |
+
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
|
| 1399 |
+
}
|
| 1400 |
+
|
| 1401 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1402 |
+
// reflect
|
| 1403 |
+
// - returns reflection of incident ray I around surface normal N
|
| 1404 |
+
// - N should be normalized, reflected vector's length is equal to length of I
|
| 1405 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1406 |
+
|
| 1407 |
+
inline __host__ __device__ float3 reflect(float3 i, float3 n)
|
| 1408 |
+
{
|
| 1409 |
+
return i - 2.0f * n * dot(n,i);
|
| 1410 |
+
}
|
| 1411 |
+
|
| 1412 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1413 |
+
// cross product
|
| 1414 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1415 |
+
|
| 1416 |
+
inline __host__ __device__ float3 cross(float3 a, float3 b)
|
| 1417 |
+
{
|
| 1418 |
+
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
|
| 1419 |
+
}
|
| 1420 |
+
|
| 1421 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1422 |
+
// smoothstep
|
| 1423 |
+
// - returns 0 if x < a
|
| 1424 |
+
// - returns 1 if x > b
|
| 1425 |
+
// - otherwise returns smooth interpolation between 0 and 1 based on x
|
| 1426 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1427 |
+
|
| 1428 |
+
inline __device__ __host__ float smoothstep(float a, float b, float x)
|
| 1429 |
+
{
|
| 1430 |
+
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1431 |
+
return (y*y*(3.0f - (2.0f*y)));
|
| 1432 |
+
}
|
| 1433 |
+
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
|
| 1434 |
+
{
|
| 1435 |
+
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1436 |
+
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
|
| 1437 |
+
}
|
| 1438 |
+
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
|
| 1439 |
+
{
|
| 1440 |
+
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1441 |
+
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
|
| 1442 |
+
}
|
| 1443 |
+
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
|
| 1444 |
+
{
|
| 1445 |
+
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1446 |
+
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
|
| 1447 |
+
}
|
| 1448 |
+
|
| 1449 |
+
#endif
|
utils/io_utils.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json, os, sys
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from typing import List, Union, Tuple, Dict
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from imageio import imread, imwrite
|
| 9 |
+
import pickle
|
| 10 |
+
import pycocotools.mask as maskUtils
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import io
|
| 15 |
+
import requests
|
| 16 |
+
import traceback
|
| 17 |
+
import base64
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
NP_BOOL_TYPES = (np.bool_, np.bool8)
|
| 22 |
+
NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
|
| 23 |
+
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
|
| 24 |
+
|
| 25 |
+
class NumpyEncoder(json.JSONEncoder):
|
| 26 |
+
def default(self, obj):
|
| 27 |
+
if isinstance(obj, np.ndarray):
|
| 28 |
+
return obj.tolist()
|
| 29 |
+
elif isinstance(obj, np.ScalarType):
|
| 30 |
+
if isinstance(obj, NP_BOOL_TYPES):
|
| 31 |
+
return bool(obj)
|
| 32 |
+
elif isinstance(obj, NP_FLOAT_TYPES):
|
| 33 |
+
return float(obj)
|
| 34 |
+
elif isinstance(obj, NP_INT_TYPES):
|
| 35 |
+
return int(obj)
|
| 36 |
+
return json.JSONEncoder.default(self, obj)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def json2dict(json_path: str):
|
| 40 |
+
with open(json_path, 'r', encoding='utf8') as f:
|
| 41 |
+
metadata = json.loads(f.read())
|
| 42 |
+
return metadata
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def dict2json(adict: dict, json_path: str):
|
| 46 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
| 47 |
+
f.write(json.dumps(adict, ensure_ascii=False, cls=NumpyEncoder))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def dict2pickle(dumped_path: str, tgt_dict: dict):
|
| 51 |
+
with open(dumped_path, "wb") as f:
|
| 52 |
+
pickle.dump(tgt_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def pickle2dict(pkl_path: str) -> Dict:
|
| 56 |
+
with open(pkl_path, "rb") as f:
|
| 57 |
+
dumped_data = pickle.load(f)
|
| 58 |
+
return dumped_data
|
| 59 |
+
|
| 60 |
+
def get_all_dirs(root_p: str) -> List[str]:
|
| 61 |
+
alldir = os.listdir(root_p)
|
| 62 |
+
dirlist = []
|
| 63 |
+
for dirp in alldir:
|
| 64 |
+
dirp = osp.join(root_p, dirp)
|
| 65 |
+
if osp.isdir(dirp):
|
| 66 |
+
dirlist.append(dirp)
|
| 67 |
+
return dirlist
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def read_filelist(filelistp: str):
|
| 71 |
+
with open(filelistp, 'r', encoding='utf8') as f:
|
| 72 |
+
lines = f.readlines()
|
| 73 |
+
if len(lines) > 0 and lines[-1].strip() == '':
|
| 74 |
+
lines = lines[:-1]
|
| 75 |
+
return lines
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
VIDEO_EXTS = {'.flv', '.mp4', '.mkv', '.ts', '.mov', 'mpeg'}
|
| 79 |
+
def get_all_videos(video_dir: str, video_exts=VIDEO_EXTS, abs_path=False) -> List[str]:
|
| 80 |
+
filelist = os.listdir(video_dir)
|
| 81 |
+
vlist = []
|
| 82 |
+
for f in filelist:
|
| 83 |
+
if Path(f).suffix in video_exts:
|
| 84 |
+
if abs_path:
|
| 85 |
+
vlist.append(osp.join(video_dir, f))
|
| 86 |
+
else:
|
| 87 |
+
vlist.append(f)
|
| 88 |
+
return vlist
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
IMG_EXT = {'.bmp', '.jpg', '.png', '.jpeg'}
|
| 92 |
+
def find_all_imgs(img_dir, abs_path=False):
|
| 93 |
+
imglist = []
|
| 94 |
+
dir_list = os.listdir(img_dir)
|
| 95 |
+
for filename in dir_list:
|
| 96 |
+
file_suffix = Path(filename).suffix
|
| 97 |
+
if file_suffix.lower() not in IMG_EXT:
|
| 98 |
+
continue
|
| 99 |
+
if abs_path:
|
| 100 |
+
imglist.append(osp.join(img_dir, filename))
|
| 101 |
+
else:
|
| 102 |
+
imglist.append(filename)
|
| 103 |
+
return imglist
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def find_all_files_recursive(tgt_dir: Union[List, str], ext, exclude_dirs={}):
|
| 107 |
+
if isinstance(tgt_dir, str):
|
| 108 |
+
tgt_dir = [tgt_dir]
|
| 109 |
+
|
| 110 |
+
filelst = []
|
| 111 |
+
for d in tgt_dir:
|
| 112 |
+
for root, _, files in os.walk(d):
|
| 113 |
+
if osp.basename(root) in exclude_dirs:
|
| 114 |
+
continue
|
| 115 |
+
for f in files:
|
| 116 |
+
if Path(f).suffix.lower() in ext:
|
| 117 |
+
filelst.append(osp.join(root, f))
|
| 118 |
+
|
| 119 |
+
return filelst
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def danbooruid2relpath(id_str: str, file_ext='.jpg'):
|
| 123 |
+
if not isinstance(id_str, str):
|
| 124 |
+
id_str = str(id_str)
|
| 125 |
+
return id_str[-3:].zfill(4) + '/' + id_str + file_ext
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_template_histvq(template: np.ndarray) -> Tuple[List[np.ndarray]]:
|
| 129 |
+
len_shape = len(template.shape)
|
| 130 |
+
num_c = 3
|
| 131 |
+
mask = None
|
| 132 |
+
if len_shape == 2:
|
| 133 |
+
num_c = 1
|
| 134 |
+
elif len_shape == 3 and template.shape[-1] == 4:
|
| 135 |
+
mask = np.where(template[..., -1])
|
| 136 |
+
template = template[..., :num_c][mask]
|
| 137 |
+
|
| 138 |
+
values, quantiles = [], []
|
| 139 |
+
for ii in range(num_c):
|
| 140 |
+
v, c = np.unique(template[..., ii].ravel(), return_counts=True)
|
| 141 |
+
q = np.cumsum(c).astype(np.float64)
|
| 142 |
+
if len(q) < 1:
|
| 143 |
+
return None, None
|
| 144 |
+
q /= q[-1]
|
| 145 |
+
values.append(v)
|
| 146 |
+
quantiles.append(q)
|
| 147 |
+
return values, quantiles
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def inplace_hist_matching(img: np.ndarray, tv: List[np.ndarray], tq: List[np.ndarray]) -> None:
|
| 151 |
+
len_shape = len(img.shape)
|
| 152 |
+
num_c = 3
|
| 153 |
+
mask = None
|
| 154 |
+
|
| 155 |
+
tgtimg = img
|
| 156 |
+
if len_shape == 2:
|
| 157 |
+
num_c = 1
|
| 158 |
+
elif len_shape == 3 and img.shape[-1] == 4:
|
| 159 |
+
mask = np.where(img[..., -1])
|
| 160 |
+
tgtimg = img[..., :num_c][mask]
|
| 161 |
+
|
| 162 |
+
im_h, im_w = img.shape[:2]
|
| 163 |
+
oldtype = img.dtype
|
| 164 |
+
for ii in range(num_c):
|
| 165 |
+
_, bin_idx, s_counts = np.unique(tgtimg[..., ii].ravel(), return_inverse=True,
|
| 166 |
+
return_counts=True)
|
| 167 |
+
s_quantiles = np.cumsum(s_counts).astype(np.float64)
|
| 168 |
+
if len(s_quantiles) == 0:
|
| 169 |
+
return
|
| 170 |
+
s_quantiles /= s_quantiles[-1]
|
| 171 |
+
interp_t_values = np.interp(s_quantiles, tq[ii], tv[ii]).astype(oldtype)
|
| 172 |
+
if mask is not None:
|
| 173 |
+
img[..., ii][mask] = interp_t_values[bin_idx]
|
| 174 |
+
else:
|
| 175 |
+
img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
|
| 176 |
+
# try:
|
| 177 |
+
# img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
|
| 178 |
+
# except:
|
| 179 |
+
# LOGGER.error('##################### sth goes wrong')
|
| 180 |
+
# cv2.imshow('img', img)
|
| 181 |
+
# cv2.waitKey(0)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def fgbg_hist_matching(fg_list: List, bg: np.ndarray, min_tq_num=128):
|
| 185 |
+
btv, btq = get_template_histvq(bg)
|
| 186 |
+
ftv, ftq = get_template_histvq(fg_list[0]['image'])
|
| 187 |
+
num_fg = len(fg_list)
|
| 188 |
+
idx_matched = -1
|
| 189 |
+
if num_fg > 1:
|
| 190 |
+
_ftv, _ftq = get_template_histvq(fg_list[0]['image'])
|
| 191 |
+
if _ftq is not None and ftq is not None:
|
| 192 |
+
if len(_ftq[0]) > len(ftq[0]):
|
| 193 |
+
idx_matched = num_fg - 1
|
| 194 |
+
ftv, ftq = _ftv, _ftq
|
| 195 |
+
else:
|
| 196 |
+
idx_matched = 0
|
| 197 |
+
|
| 198 |
+
if btq is not None and ftq is not None:
|
| 199 |
+
if len(btq[0]) > len(ftq[0]):
|
| 200 |
+
tv, tq = btv, btq
|
| 201 |
+
idx_matched = -1
|
| 202 |
+
else:
|
| 203 |
+
tv, tq = ftv, ftq
|
| 204 |
+
if len(tq[0]) > min_tq_num:
|
| 205 |
+
inplace_hist_matching(bg, tv, tq)
|
| 206 |
+
|
| 207 |
+
if len(tq[0]) > min_tq_num:
|
| 208 |
+
for ii, fg_dict in enumerate(fg_list):
|
| 209 |
+
fg = fg_dict['image']
|
| 210 |
+
if ii != idx_matched and len(tq[0]) > min_tq_num:
|
| 211 |
+
inplace_hist_matching(fg, tv, tq)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def imread_nogrey_rgb(imp: str) -> np.ndarray:
|
| 215 |
+
img: np.ndarray = imread(imp)
|
| 216 |
+
c = 1
|
| 217 |
+
if len(img.shape) == 3:
|
| 218 |
+
c = img.shape[-1]
|
| 219 |
+
if c == 1:
|
| 220 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 221 |
+
if c == 4:
|
| 222 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
|
| 223 |
+
return img
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value: Tuple = (114, 114, 114)):
|
| 227 |
+
h, w = img.shape[:2]
|
| 228 |
+
pad_h, pad_w = 0, 0
|
| 229 |
+
|
| 230 |
+
# make square image
|
| 231 |
+
if w < h:
|
| 232 |
+
pad_w = h - w
|
| 233 |
+
w += pad_w
|
| 234 |
+
elif h < w:
|
| 235 |
+
pad_h = w - h
|
| 236 |
+
h += pad_h
|
| 237 |
+
|
| 238 |
+
pad_size = tgt_size - h
|
| 239 |
+
if pad_size > 0:
|
| 240 |
+
pad_h += pad_size
|
| 241 |
+
pad_w += pad_size
|
| 242 |
+
|
| 243 |
+
if pad_h > 0 or pad_w > 0:
|
| 244 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
|
| 245 |
+
|
| 246 |
+
down_scale_ratio = tgt_size / img.shape[0]
|
| 247 |
+
assert down_scale_ratio <= 1
|
| 248 |
+
if down_scale_ratio < 1:
|
| 249 |
+
img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
|
| 250 |
+
|
| 251 |
+
return img, down_scale_ratio, pad_h, pad_w
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def scaledown_maxsize(img: np.ndarray, max_size: int, divisior: int = None):
|
| 255 |
+
|
| 256 |
+
im_h, im_w = img.shape[:2]
|
| 257 |
+
ori_h, ori_w = img.shape[:2]
|
| 258 |
+
resize_ratio = max_size / max(im_h, im_w)
|
| 259 |
+
if resize_ratio < 1:
|
| 260 |
+
if im_h > im_w:
|
| 261 |
+
im_h = max_size
|
| 262 |
+
im_w = max(1, int(round(im_w * resize_ratio)))
|
| 263 |
+
|
| 264 |
+
else:
|
| 265 |
+
im_w = max_size
|
| 266 |
+
im_h = max(1, int(round(im_h * resize_ratio)))
|
| 267 |
+
if divisior is not None:
|
| 268 |
+
im_w = int(np.ceil(im_w / divisior) * divisior)
|
| 269 |
+
im_h = int(np.ceil(im_h / divisior) * divisior)
|
| 270 |
+
|
| 271 |
+
if im_w != ori_w or im_h != ori_h:
|
| 272 |
+
img = cv2.resize(img, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
|
| 273 |
+
|
| 274 |
+
return img
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def resize_pad(img: np.ndarray, tgt_size: int, pad_value: Tuple = (0, 0, 0)):
|
| 278 |
+
# downscale to tgt_size and pad to square
|
| 279 |
+
img = scaledown_maxsize(img, tgt_size)
|
| 280 |
+
padl, padr, padt, padb = 0, 0, 0, 0
|
| 281 |
+
h, w = img.shape[:2]
|
| 282 |
+
# padt = (tgt_size - h) // 2
|
| 283 |
+
# padb = tgt_size - h - padt
|
| 284 |
+
# padl = (tgt_size - w) // 2
|
| 285 |
+
# padr = tgt_size - w - padl
|
| 286 |
+
padb = tgt_size - h
|
| 287 |
+
padr = tgt_size - w
|
| 288 |
+
|
| 289 |
+
if padt + padb + padl + padr > 0:
|
| 290 |
+
img = cv2.copyMakeBorder(img, padt, padb, padl, padr, cv2.BORDER_CONSTANT, value=pad_value)
|
| 291 |
+
|
| 292 |
+
return img, (padt, padb, padl, padr)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def resize_pad2divisior(img: np.ndarray, tgt_size: int, divisior: int = 64, pad_value: Tuple = (0, 0, 0)):
|
| 296 |
+
img = scaledown_maxsize(img, tgt_size)
|
| 297 |
+
img, (pad_h, pad_w) = pad2divisior(img, divisior, pad_value)
|
| 298 |
+
return img, (pad_h, pad_w)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def img2grey(img: Union[np.ndarray, str], is_rgb: bool = False) -> np.ndarray:
|
| 302 |
+
if isinstance(img, np.ndarray):
|
| 303 |
+
if len(img.shape) == 3:
|
| 304 |
+
if img.shape[-1] != 1:
|
| 305 |
+
if is_rgb:
|
| 306 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 307 |
+
else:
|
| 308 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 309 |
+
else:
|
| 310 |
+
img = img[..., 0]
|
| 311 |
+
return img
|
| 312 |
+
elif isinstance(img, str):
|
| 313 |
+
return cv2.imread(img, cv2.IMREAD_GRAYSCALE)
|
| 314 |
+
else:
|
| 315 |
+
raise NotImplementedError
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def pad2divisior(img: np.ndarray, divisior: int, value = (0, 0, 0)) -> np.ndarray:
|
| 319 |
+
im_h, im_w = img.shape[:2]
|
| 320 |
+
pad_h = int(np.ceil(im_h / divisior)) * divisior - im_h
|
| 321 |
+
pad_w = int(np.ceil(im_w / divisior)) * divisior - im_w
|
| 322 |
+
if pad_h != 0 or pad_w != 0:
|
| 323 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, value=value, borderType=cv2.BORDER_CONSTANT)
|
| 324 |
+
return img, (pad_h, pad_w)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def mask2rle(mask: np.ndarray, decode_for_json: bool = True) -> Dict:
|
| 328 |
+
mask_rle = maskUtils.encode(np.array(
|
| 329 |
+
mask[..., np.newaxis] > 0, order='F',
|
| 330 |
+
dtype='uint8'))[0]
|
| 331 |
+
if decode_for_json:
|
| 332 |
+
mask_rle['counts'] = mask_rle['counts'].decode()
|
| 333 |
+
return mask_rle
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def bbox2xyxy(box) -> Tuple[int]:
|
| 337 |
+
x1, y1 = box[0], box[1]
|
| 338 |
+
return x1, y1, x1+box[2], y1+box[3]
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def bbox_overlap_area(abox, boxb) -> int:
|
| 342 |
+
ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
|
| 343 |
+
bx1, by1, bx2, by2 = bbox2xyxy(boxb)
|
| 344 |
+
|
| 345 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
| 346 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
| 347 |
+
|
| 348 |
+
if ix > 0 and iy > 0:
|
| 349 |
+
return ix * iy
|
| 350 |
+
else:
|
| 351 |
+
return 0
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def bbox_overlap_xy(abox, boxb) -> Tuple[int]:
|
| 355 |
+
ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
|
| 356 |
+
bx1, by1, bx2, by2 = bbox2xyxy(boxb)
|
| 357 |
+
|
| 358 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
| 359 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
| 360 |
+
|
| 361 |
+
return ix, iy
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def xyxy_overlap_area(axyxy, bxyxy) -> int:
|
| 365 |
+
ax1, ay1, ax2, ay2 = axyxy
|
| 366 |
+
bx1, by1, bx2, by2 = bxyxy
|
| 367 |
+
|
| 368 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
| 369 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
| 370 |
+
|
| 371 |
+
if ix > 0 and iy > 0:
|
| 372 |
+
return ix * iy
|
| 373 |
+
else:
|
| 374 |
+
return 0
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
DIRNAME2TAG = {'rezero': 're:zero'}
|
| 378 |
+
def dirname2charactername(dirname, start=6):
|
| 379 |
+
cname = dirname[start:]
|
| 380 |
+
for k, v in DIRNAME2TAG.items():
|
| 381 |
+
cname = cname.replace(k, v)
|
| 382 |
+
return cname
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def imglist2grid(imglist: np.ndarray, grid_size: int = 384, col=None) -> np.ndarray:
|
| 386 |
+
sqimlist = []
|
| 387 |
+
for img in imglist:
|
| 388 |
+
sqimlist.append(square_pad_resize(img, grid_size)[0])
|
| 389 |
+
|
| 390 |
+
nimg = len(imglist)
|
| 391 |
+
if nimg == 0:
|
| 392 |
+
return None
|
| 393 |
+
padn = 0
|
| 394 |
+
if col is None:
|
| 395 |
+
if nimg > 5:
|
| 396 |
+
row = int(np.round(np.sqrt(nimg)))
|
| 397 |
+
col = int(np.ceil(nimg / row))
|
| 398 |
+
else:
|
| 399 |
+
col = nimg
|
| 400 |
+
|
| 401 |
+
padn = int(np.ceil(nimg / col) * col) - nimg
|
| 402 |
+
if padn != 0:
|
| 403 |
+
padimg = np.zeros_like(sqimlist[0])
|
| 404 |
+
for _ in range(padn):
|
| 405 |
+
sqimlist.append(padimg)
|
| 406 |
+
|
| 407 |
+
return rearrange(sqimlist, '(row col) h w c -> (row h) (col w) c', col=col)
|
| 408 |
+
|
| 409 |
+
def write_jsonlines(filep: str, dict_lst: List[str], progress_bar: bool = True):
|
| 410 |
+
with open(filep, 'w') as out:
|
| 411 |
+
if progress_bar:
|
| 412 |
+
lst = tqdm(dict_lst)
|
| 413 |
+
else:
|
| 414 |
+
lst = dict_lst
|
| 415 |
+
for ddict in lst:
|
| 416 |
+
jout = json.dumps(ddict) + '\n'
|
| 417 |
+
out.write(jout)
|
| 418 |
+
|
| 419 |
+
def read_jsonlines(filep: str):
|
| 420 |
+
with open(filep, 'r', encoding='utf8') as f:
|
| 421 |
+
result = [json.loads(jline) for jline in f.read().splitlines()]
|
| 422 |
+
return result
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _b64encode(x: bytes) -> str:
|
| 426 |
+
return base64.b64encode(x).decode("utf-8")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def img2b64(img):
|
| 430 |
+
"""
|
| 431 |
+
Convert a PIL image to a base64-encoded string.
|
| 432 |
+
"""
|
| 433 |
+
if isinstance(img, np.ndarray):
|
| 434 |
+
img = Image.fromarray(img)
|
| 435 |
+
buffered = io.BytesIO()
|
| 436 |
+
img.save(buffered, format='PNG')
|
| 437 |
+
return _b64encode(buffered.getvalue())
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def save_encoded_image(b64_image: str, output_path: str):
|
| 441 |
+
with open(output_path, "wb") as image_file:
|
| 442 |
+
image_file.write(base64.b64decode(b64_image))
|
| 443 |
+
|
| 444 |
+
def submit_request(url, data, exist_on_exception=True, auth=None, wait_time = 30):
|
| 445 |
+
response = None
|
| 446 |
+
try:
|
| 447 |
+
while True:
|
| 448 |
+
try:
|
| 449 |
+
response = requests.post(url, data=data, auth=auth)
|
| 450 |
+
response.raise_for_status()
|
| 451 |
+
break
|
| 452 |
+
except Exception as e:
|
| 453 |
+
if wait_time > 0:
|
| 454 |
+
print(traceback.format_exc(), file=sys.stderr)
|
| 455 |
+
print(f'sleep {wait_time} sec...')
|
| 456 |
+
time.sleep(wait_time)
|
| 457 |
+
continue
|
| 458 |
+
else:
|
| 459 |
+
raise e
|
| 460 |
+
except Exception as e:
|
| 461 |
+
print(traceback.format_exc(), file=sys.stderr)
|
| 462 |
+
if response is not None:
|
| 463 |
+
print('response content: ' + response.text)
|
| 464 |
+
if exist_on_exception:
|
| 465 |
+
exit()
|
| 466 |
+
return response
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# def resize_image(input_image, resolution):
|
| 470 |
+
# H, W = input_image.shape[:2]
|
| 471 |
+
# k = float(min(resolution)) / min(H, W)
|
| 472 |
+
# img = cv2.resize(input_image, resolution, interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
| 473 |
+
# return img
|
utils/logger.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from termcolor import colored
|
| 4 |
+
|
| 5 |
+
def set_logging(name=None, verbose=True):
|
| 6 |
+
for handler in logging.root.handlers[:]:
|
| 7 |
+
logging.root.removeHandler(handler)
|
| 8 |
+
# Sets level and returns logger
|
| 9 |
+
# rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
| 10 |
+
fmt = (
|
| 11 |
+
# colored("[%(name)s]", "magenta", attrs=["bold"])
|
| 12 |
+
colored("[%(asctime)s]", "blue")
|
| 13 |
+
+ colored("%(levelname)s:", "green")
|
| 14 |
+
+ colored("%(message)s", "white")
|
| 15 |
+
)
|
| 16 |
+
logging.basicConfig(format=fmt, level=logging.INFO if verbose else logging.WARNING)
|
| 17 |
+
return logging.getLogger(name)
|
| 18 |
+
|
| 19 |
+
LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)
|
| 20 |
+
|
utils/mmdet_custom_hooks.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mmengine.fileio import FileClient
|
| 2 |
+
from mmengine.dist import master_only
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
import torch
|
| 5 |
+
import mmcv
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import cv2
|
| 9 |
+
from typing import Optional, Sequence
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from mmdet.apis import inference_detector
|
| 12 |
+
from mmcv.transforms import Compose
|
| 13 |
+
from mmdet.engine import DetVisualizationHook
|
| 14 |
+
from mmdet.registry import HOOKS
|
| 15 |
+
from mmdet.structures import DetDataSample
|
| 16 |
+
|
| 17 |
+
from utils.io_utils import find_all_imgs, square_pad_resize, imglist2grid
|
| 18 |
+
|
| 19 |
+
def inference_detector(
|
| 20 |
+
model: nn.Module,
|
| 21 |
+
imgs,
|
| 22 |
+
test_pipeline
|
| 23 |
+
):
|
| 24 |
+
|
| 25 |
+
if isinstance(imgs, (list, tuple)):
|
| 26 |
+
is_batch = True
|
| 27 |
+
else:
|
| 28 |
+
imgs = [imgs]
|
| 29 |
+
is_batch = False
|
| 30 |
+
|
| 31 |
+
if len(imgs) == 0:
|
| 32 |
+
return []
|
| 33 |
+
|
| 34 |
+
test_pipeline = test_pipeline.copy()
|
| 35 |
+
if isinstance(imgs[0], np.ndarray):
|
| 36 |
+
# Calling this method across libraries will result
|
| 37 |
+
# in module unregistered error if not prefixed with mmdet.
|
| 38 |
+
test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
|
| 39 |
+
|
| 40 |
+
test_pipeline = Compose(test_pipeline)
|
| 41 |
+
|
| 42 |
+
result_list = []
|
| 43 |
+
for img in imgs:
|
| 44 |
+
# prepare data
|
| 45 |
+
if isinstance(img, np.ndarray):
|
| 46 |
+
# TODO: remove img_id.
|
| 47 |
+
data_ = dict(img=img, img_id=0)
|
| 48 |
+
else:
|
| 49 |
+
# TODO: remove img_id.
|
| 50 |
+
data_ = dict(img_path=img, img_id=0)
|
| 51 |
+
# build the data pipeline
|
| 52 |
+
data_ = test_pipeline(data_)
|
| 53 |
+
|
| 54 |
+
data_['inputs'] = [data_['inputs']]
|
| 55 |
+
data_['data_samples'] = [data_['data_samples']]
|
| 56 |
+
|
| 57 |
+
# forward the model
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
results = model.test_step(data_)[0]
|
| 60 |
+
|
| 61 |
+
result_list.append(results)
|
| 62 |
+
|
| 63 |
+
if not is_batch:
|
| 64 |
+
return result_list[0]
|
| 65 |
+
else:
|
| 66 |
+
return result_list
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@HOOKS.register_module()
|
| 70 |
+
class InstanceSegVisualizationHook(DetVisualizationHook):
|
| 71 |
+
|
| 72 |
+
def __init__(self, visualize_samples: str = '',
|
| 73 |
+
read_rgb: bool = False,
|
| 74 |
+
draw: bool = False,
|
| 75 |
+
interval: int = 50,
|
| 76 |
+
score_thr: float = 0.3,
|
| 77 |
+
show: bool = False,
|
| 78 |
+
wait_time: float = 0.,
|
| 79 |
+
test_out_dir: Optional[str] = None,
|
| 80 |
+
file_client_args: dict = dict(backend='disk')):
|
| 81 |
+
super().__init__(draw, interval, score_thr, show, wait_time, test_out_dir, file_client_args)
|
| 82 |
+
self.vis_samples = []
|
| 83 |
+
|
| 84 |
+
if osp.exists(visualize_samples):
|
| 85 |
+
self.channel_order = channel_order = 'rgb' if read_rgb else 'bgr'
|
| 86 |
+
samples = find_all_imgs(visualize_samples, abs_path=True)
|
| 87 |
+
for imgp in samples:
|
| 88 |
+
img = mmcv.imread(imgp, channel_order=channel_order)
|
| 89 |
+
img, _, _, _ = square_pad_resize(img, 640)
|
| 90 |
+
self.vis_samples.append(img)
|
| 91 |
+
|
| 92 |
+
def before_val(self, runner) -> None:
|
| 93 |
+
total_curr_iter = runner.iter
|
| 94 |
+
self._visualize_data(total_curr_iter, runner)
|
| 95 |
+
return super().before_val(runner)
|
| 96 |
+
|
| 97 |
+
# def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
| 98 |
+
# outputs: Sequence[DetDataSample]) -> None:
|
| 99 |
+
# """Run after every ``self.interval`` validation iterations.
|
| 100 |
+
|
| 101 |
+
# Args:
|
| 102 |
+
# runner (:obj:`Runner`): The runner of the validation process.
|
| 103 |
+
# batch_idx (int): The index of the current batch in the val loop.
|
| 104 |
+
# data_batch (dict): Data from dataloader.
|
| 105 |
+
# outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
|
| 106 |
+
# that contain annotations and predictions.
|
| 107 |
+
# """
|
| 108 |
+
# # if self.draw is False:
|
| 109 |
+
# # return
|
| 110 |
+
|
| 111 |
+
# if self.file_client is None:
|
| 112 |
+
# self.file_client = FileClient(**self.file_client_args)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# # There is no guarantee that the same batch of images
|
| 116 |
+
# # is visualized for each evaluation.
|
| 117 |
+
# total_curr_iter = runner.iter + batch_idx
|
| 118 |
+
|
| 119 |
+
# # # Visualize only the first data
|
| 120 |
+
# # img_path = outputs[0].img_path
|
| 121 |
+
# # img_bytes = self.file_client.get(img_path)
|
| 122 |
+
# # img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
| 123 |
+
# if total_curr_iter % self.interval == 0 and self.vis_samples:
|
| 124 |
+
# self._visualize_data(total_curr_iter, runner)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@master_only
|
| 128 |
+
def _visualize_data(self, total_curr_iter, runner):
|
| 129 |
+
|
| 130 |
+
tgt_size = 384
|
| 131 |
+
|
| 132 |
+
runner.model.eval()
|
| 133 |
+
outputs = inference_detector(runner.model, self.vis_samples, test_pipeline=runner.cfg.test_pipeline)
|
| 134 |
+
vis_results = []
|
| 135 |
+
for img, output in zip(self.vis_samples, outputs):
|
| 136 |
+
vis_img = self.add_datasample(
|
| 137 |
+
'val_img',
|
| 138 |
+
img,
|
| 139 |
+
data_sample=output,
|
| 140 |
+
show=self.show,
|
| 141 |
+
wait_time=self.wait_time,
|
| 142 |
+
pred_score_thr=self.score_thr,
|
| 143 |
+
draw_gt=False,
|
| 144 |
+
step=total_curr_iter)
|
| 145 |
+
vis_results.append(cv2.resize(vis_img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA))
|
| 146 |
+
|
| 147 |
+
drawn_img = imglist2grid(vis_results, tgt_size)
|
| 148 |
+
if drawn_img is None:
|
| 149 |
+
return
|
| 150 |
+
drawn_img = cv2.cvtColor(drawn_img, cv2.COLOR_BGR2RGB)
|
| 151 |
+
visualizer = self._visualizer
|
| 152 |
+
visualizer.set_image(drawn_img)
|
| 153 |
+
visualizer.add_image('val_img', drawn_img, total_curr_iter)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@master_only
|
| 157 |
+
def add_datasample(
|
| 158 |
+
self,
|
| 159 |
+
name: str,
|
| 160 |
+
image: np.ndarray,
|
| 161 |
+
data_sample: Optional['DetDataSample'] = None,
|
| 162 |
+
draw_gt: bool = True,
|
| 163 |
+
draw_pred: bool = True,
|
| 164 |
+
show: bool = False,
|
| 165 |
+
wait_time: float = 0,
|
| 166 |
+
# TODO: Supported in mmengine's Viusalizer.
|
| 167 |
+
out_file: Optional[str] = None,
|
| 168 |
+
pred_score_thr: float = 0.3,
|
| 169 |
+
step: int = 0) -> np.ndarray:
|
| 170 |
+
image = image.clip(0, 255).astype(np.uint8)
|
| 171 |
+
visualizer = self._visualizer
|
| 172 |
+
classes = visualizer.dataset_meta.get('classes', None)
|
| 173 |
+
palette = visualizer.dataset_meta.get('palette', None)
|
| 174 |
+
|
| 175 |
+
gt_img_data = None
|
| 176 |
+
pred_img_data = None
|
| 177 |
+
|
| 178 |
+
if data_sample is not None:
|
| 179 |
+
data_sample = data_sample.cpu()
|
| 180 |
+
|
| 181 |
+
if draw_gt and data_sample is not None:
|
| 182 |
+
gt_img_data = image
|
| 183 |
+
if 'gt_instances' in data_sample:
|
| 184 |
+
gt_img_data = visualizer._draw_instances(image,
|
| 185 |
+
data_sample.gt_instances,
|
| 186 |
+
classes, palette)
|
| 187 |
+
|
| 188 |
+
if 'gt_panoptic_seg' in data_sample:
|
| 189 |
+
assert classes is not None, 'class information is ' \
|
| 190 |
+
'not provided when ' \
|
| 191 |
+
'visualizing panoptic ' \
|
| 192 |
+
'segmentation results.'
|
| 193 |
+
gt_img_data = visualizer._draw_panoptic_seg(
|
| 194 |
+
gt_img_data, data_sample.gt_panoptic_seg, classes)
|
| 195 |
+
|
| 196 |
+
if draw_pred and data_sample is not None:
|
| 197 |
+
pred_img_data = image
|
| 198 |
+
if 'pred_instances' in data_sample:
|
| 199 |
+
pred_instances = data_sample.pred_instances
|
| 200 |
+
pred_instances = pred_instances[
|
| 201 |
+
pred_instances.scores > pred_score_thr]
|
| 202 |
+
pred_img_data = visualizer._draw_instances(image, pred_instances,
|
| 203 |
+
classes, palette)
|
| 204 |
+
if 'pred_panoptic_seg' in data_sample:
|
| 205 |
+
assert classes is not None, 'class information is ' \
|
| 206 |
+
'not provided when ' \
|
| 207 |
+
'visualizing panoptic ' \
|
| 208 |
+
'segmentation results.'
|
| 209 |
+
pred_img_data = visualizer._draw_panoptic_seg(
|
| 210 |
+
pred_img_data, data_sample.pred_panoptic_seg.numpy(),
|
| 211 |
+
classes)
|
| 212 |
+
|
| 213 |
+
if gt_img_data is not None and pred_img_data is not None:
|
| 214 |
+
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
|
| 215 |
+
elif gt_img_data is not None:
|
| 216 |
+
drawn_img = gt_img_data
|
| 217 |
+
elif pred_img_data is not None:
|
| 218 |
+
drawn_img = pred_img_data
|
| 219 |
+
else:
|
| 220 |
+
# Display the original image directly if nothing is drawn.
|
| 221 |
+
drawn_img = image
|
| 222 |
+
|
| 223 |
+
return drawn_img
|