diff --git a/Anymate/.gitignore b/Anymate/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d4162f599d7bbc11d4337c8fdeecc8789dcb909e
--- /dev/null
+++ b/Anymate/.gitignore
@@ -0,0 +1,26 @@
+__pycache__
+*.pt
+*.tar
+*.tar
+*.txt
+*.glb*
+*.obj
+*.ckpt
+*.blend
+*.blend1
+test_*
+
+blender-*
+*.json*
+*.glb
+*.gltf
+*.fbx
+*.FBX
+*.dae
+*.obj
+*.mtl
+*.binvox
+*.csv
+*.tga
+*.png
+*.jpg
\ No newline at end of file
diff --git a/Anymate/__init__.py b/Anymate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Anymate/args.py b/Anymate/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..715dedbb118882b926d0480978a56fa7cd710ca4
--- /dev/null
+++ b/Anymate/args.py
@@ -0,0 +1,22 @@
+class AnymateArgs:
+ def __init__(self):
+ # self.encoder = "miche"
+ # self.decoder = "transformer_latent"
+ # self.dataset = "train"
+ # self.run_name = "miche-transformer_latent-train-8gpu-finetune"
+ self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
+ self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
+ self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
+
+ self.device = "cuda"
+ self.num_joints = 96
+
+
+class UIArgs:
+ def __init__(self):
+ self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
+ self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
+ self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
+
+ui_args = UIArgs()
+anymate_args = AnymateArgs()
\ No newline at end of file
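The classes above only bundle checkpoint paths and device settings. A minimal sketch of how a caller might consume them, assuming the .pth.tar files are ordinary torch.save archives (the exact keys inside the archives are not shown in this diff):

    import torch
    from Anymate.args import anymate_args

    # Load the joint-prediction checkpoint onto CPU first; move tensors to
    # anymate_args.device only after inspecting the archive layout.
    ckpt = torch.load(anymate_args.checkpoint_joint, map_location='cpu')
    print(type(ckpt), list(ckpt.keys()) if isinstance(ckpt, dict) else None)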
diff --git a/Anymate/blender_script.py b/Anymate/blender_script.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55226add7c137073d90555380790d1a34133441
--- /dev/null
+++ b/Anymate/blender_script.py
@@ -0,0 +1,747 @@
+import bpy
+import mathutils
+from mathutils import Vector, Matrix
+
+import os
+import sys
+import random
+import numpy as np
+import json
+import argparse
+
+
+IMPORT_FUNCTIONS = {
+ "obj": bpy.ops.wm.obj_import,
+ "glb": bpy.ops.import_scene.gltf,
+ "gltf": bpy.ops.import_scene.gltf,
+ "usd": bpy.ops.import_scene.usd,
+ "fbx": bpy.ops.import_scene.fbx,
+ "stl": bpy.ops.import_mesh.stl,
+ "usda": bpy.ops.import_scene.usda,
+ "dae": bpy.ops.wm.collada_import,
+ "ply": bpy.ops.import_mesh.ply,
+ "abc": bpy.ops.wm.alembic_import,
+ "blend": bpy.ops.wm.append,
+}
+
+def load_object(object_path: str) -> None:
+ """Loads a model with a supported file extension into the scene.
+
+ Args:
+ object_path (str): Path to the model file.
+
+ Raises:
+ ValueError: If the file extension is not supported.
+
+ Returns:
+ None
+ """
+    file_extension = object_path.split(".")[-1].lower()
+    if file_extension not in IMPORT_FUNCTIONS:
+        raise ValueError(f"Unsupported file type: {object_path}")
+
+ # load from existing import functions
+ import_function = IMPORT_FUNCTIONS[file_extension]
+
+ if file_extension == "blend":
+ import_function(directory=object_path, link=False)
+ elif file_extension in {"glb", "gltf"}:
+ import_function(filepath=object_path, merge_vertices=True)
+ else:
+ import_function(filepath=object_path)
+
+####################### save json ################################
+def save_json(output_path, mesh_obj, armature_obj, extra=None, arm_name=False):
+ # makedirs output_path
+ os.makedirs(output_path, exist_ok=True)
+
+    # start retrieving the mesh, skinning and rigging information
+
+    # 1. retrieve the rigging information and save the world matrix of the armature object
+ total_armature_info = {}
+ for obj in armature_obj:
+ # depsgraph = bpy.context.evaluated_depsgraph_get()
+ # obj = obj.evaluated_get(depsgraph)
+ armature_info = {}
+ armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
+ translation = obj.matrix_world.translation
+ for bone in obj.pose.bones:
+ bone_info = {}
+ bone_info["head_local"] = list(bone.head.copy())
+ bone_info["head_world"] = list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())
+ # bone_info["matrix_local"] = [list(row) for row in bone.matrix_local.copy()]
+ bone_info["tail_local"] = list(bone.tail.copy())
+ bone_info["tail_world"] = list((obj.matrix_world.to_3x3() @ bone.tail+translation).copy())
+
+ if bone.parent:
+ bone_info["parent"] = bone.parent.name.replace(" ", "_")
+ if arm_name:
+ bone_info["parent"] = obj.name + "--" + bone_info["parent"]
+ else:
+ bone_info["parent"] = None
+ bone_info["children"] = []
+ if bone.children:
+ for child in bone.children:
+ if arm_name:
+ bone_info["children"].append(obj.name + "--" + child.name.replace(" ", "_"))
+ else:
+ bone_info["children"].append(child.name.replace(" ", "_"))
+ bone_name = bone.name.replace(" ", "_")
+ if arm_name:
+ bone_name = obj.name + "--" + bone_name
+ armature_info[bone_name] = bone_info
+ obj_name = obj.name.replace(" ", "_")
+ total_armature_info[obj.name] = armature_info
+
+
+    # 2. retrieve the skinning information
+ total_skinning_info = {}
+ for obj in mesh_obj:
+ vertex_groups = obj.vertex_groups
+ # if not vertex_groups:
+ # continue
+ # for group in vertex_groups:
+ skinning_info = {}
+ skinning_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
+ weight_info = []
+ for vertex in obj.data.vertices:
+ vertex_info = {}
+ for group in vertex.groups:
+ name = vertex_groups[group.group].name
+ name = name.replace(" ", "_")
+ if arm_name:
+ arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
+ assert(len(arm_modifier) == 1)
+ name = arm_modifier[0].object.name + "--" + name
+ weight = group.weight
+ vertex_info[name] = weight
+ weight_info.append(vertex_info)
+ skinning_info["weight"] = weight_info
+ obj_name = obj.name.replace(" ", "_")
+ total_skinning_info[obj_name]=skinning_info
+
+
+ rigging_file_path = os.path.join(output_path, "rigging.json")
+ if extra:
+ rigging_file_path = rigging_file_path.replace("rigging.json", f'rigging_{extra}.json')
+ with open(rigging_file_path, "w") as f:
+ json.dump(total_armature_info, f, indent = 2)
+
+ skining_file_path = os.path.join(output_path, "skining.json")
+ if extra:
+ skining_file_path = skining_file_path.replace("skining.json", f'skining_{extra}.json')
+ with open(skining_file_path, "w") as f:
+ json.dump(total_skinning_info, f , indent = 2)
+
+
+ return rigging_file_path
+
+
+def apply_skinning_weights(json_file):
+
+ with open(json_file, "r") as f:
+ skinning_data = json.load(f)
+
+ armature_obj = bpy.data.objects.get("Armature")
+ if not armature_obj:
+ print("Error: Armature object 'Armature' not found.")
+ return
+
+    # parent all mesh objects to the armature object
+ count = 0
+ for obj in bpy.context.scene.objects:
+ if obj.type == 'MESH':
+ obj.parent = armature_obj
+ count += 1
+
+ print("total mesh count:", count)
+
+ for obj in bpy.context.scene.objects:
+ vertex_index = 0
+ if obj.type == 'MESH':
+ mesh_name = obj.name
+ if mesh_name in skinning_data:
+ skinning_info = skinning_data[mesh_name]
+ if "weight" in skinning_info:
+ print("Applying skinning data for mesh:", mesh_name)
+ vertex_index = 0
+ for vertex_weight in skinning_info["weight"]:
+ for bone_name, weight_value in vertex_weight.items():
+ vertex_group = obj.vertex_groups.get(bone_name)
+ if vertex_group is None:
+ vertex_group = obj.vertex_groups.new(name=bone_name)
+ print("Vertex group created:", bone_name)
+ vertex_group.add([vertex_index], weight_value, 'REPLACE')
+ vertex_index += 1
+ else:
+ print("No skinning data found for mesh:", mesh_name)
+ for obj in bpy.context.scene.objects:
+ if obj.type == 'MESH':
+ modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
+ modifier.object = armature_obj
+ modifier.use_vertex_groups = True
+ print("Armature modifier added to mesh:", obj.name)
+
+def reload_rigging(rigging_file_path):
+ with open(rigging_file_path, "r") as f:
+ total_armature_info = json.load(f)
+
+ bpy.ops.object.armature_add()
+ armature_obj = bpy.context.object
+ armature_obj.name = "Armature"
+
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.armature.select_all(action='SELECT')
+ bpy.ops.armature.delete()
+ bpy.ops.object.mode_set(mode='OBJECT')
+ bpy.ops.object.mode_set(mode='EDIT')
+
+ world_matrix = mathutils.Matrix([[1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ armature_obj.matrix_world = world_matrix
+
+ for armature_name, armature_info in total_armature_info.items():
+ for bone_name, bone_info in armature_info.items():
+ if bone_name == "world_matrix":
+ continue
+ bone = armature_obj.data.edit_bones.new(bone_name)
+ bone.head = bone_info["head_world"]
+ bone.tail = bone_info["tail_world"]
+
+ for bone_name, bone_info in armature_info.items():
+ if bone_name == "world_matrix":
+ continue
+ bone = armature_obj.data.edit_bones[bone_name]
+ parent_name = bone_info["parent"]
+ if parent_name:
+ parent_bone = armature_obj.data.edit_bones[parent_name]
+ bone.parent = parent_bone
+ edit_len = len(armature_obj.data.edit_bones.keys())
+ bpy.ops.object.mode_set(mode='OBJECT')
+ bone_len = len(armature_obj.data.bones.keys())
+    assert edit_len == bone_len, "bone number does not match! " + str(edit_len) + " " + str(bone_len)
+ bpy.ops.object.select_all(action='DESELECT')
+ armature_obj.select_set(True)
+ bpy.context.view_layer.objects.active = armature_obj
+ print("Rigging information has been reloaded!")
+
+############################# reload json ################################
+def reload_json(folder_path, version=0, export = None):
+ bpy.ops.wm.read_homefile(use_empty=True)
+ if version == 0:
+ obj_path = os.path.join(folder_path, "object.obj")
+ skinning_file_path = os.path.join(folder_path, "skining.json")
+ rigging_file_path = os.path.join(folder_path, "rigging.json")
+ elif version == 1:
+ obj_path = os.path.join(folder_path, "join.obj")
+ skinning_file_path = os.path.join(folder_path, "skining_norig.json")
+ rigging_file_path = os.path.join(folder_path, "rigging_norig.json")
+ elif version == 2:
+ obj_path = os.path.join(folder_path, "join.obj")
+ skinning_file_path = os.path.join(folder_path, "skining_norig2.json")
+ rigging_file_path = os.path.join(folder_path, "rigging_norig2.json")
+ # import_obj(obj_path)
+ load_object(obj_path)
+ reload_rigging(rigging_file_path)
+ apply_skinning_weights(skinning_file_path)
+ if export:
+ bpy.ops.wm.save_as_mainfile(filepath=export)
+ print("Done!")
+
+
+def reset_scene() -> None:
+ """Resets the scene to a clean state.
+
+ Returns:
+ None
+ """
+ # delete everything that isn't part of a camera or a light
+ for obj in bpy.data.objects:
+ if obj.type not in {"CAMERA", "LIGHT"}:
+ bpy.data.objects.remove(obj, do_unlink=True)
+
+ # delete all the materials
+ for material in bpy.data.materials:
+ bpy.data.materials.remove(material, do_unlink=True)
+
+ # delete all the textures
+ for texture in bpy.data.textures:
+ bpy.data.textures.remove(texture, do_unlink=True)
+
+ # delete all the images
+ for image in bpy.data.images:
+ bpy.data.images.remove(image, do_unlink=True)
+
+
+def save_mesh(path, mtl=False, obj_path=None):
+ if mtl:
+ # save the blend file
+ bpy.ops.wm.save_as_mainfile(filepath=obj_path + '/object.blend')
+ # reopen the blend file
+ bpy.ops.wm.open_mainfile(filepath=obj_path + '/object.blend')
+ # unpack all the materials and textures to obj_path
+ bpy.ops.file.unpack_all(method='WRITE_LOCAL')
+    # export to .obj; materials and UVs are written only when mtl=True
+ bpy.ops.wm.obj_export(filepath=path, export_materials=mtl, export_uv=mtl, export_triangulated_mesh=True)
+
+
+def get_root_obj(obj):
+ if not obj.parent:
+ return obj
+ return get_root_obj(obj.parent)
+
+def normalize(objs):
+ # bpy.ops.object.select_all(action='DESELECT')
+ # # select objs and join them
+ # for obj in objs:
+ # obj.select_set(True)
+ # bpy.context.view_layer.objects.active = objs[0]
+ # name_join = objs[0].name
+ # bpy.ops.object.join()
+ # obj_join = bpy.context.active_object
+ # print(obj_join.matrix_world)
+ # print(name_join)
+ # assert(name_join == obj_join.name)
+
+ objs_eval = []
+ depsgraph = bpy.context.evaluated_depsgraph_get()
+ for obj in objs:
+ objs_eval.append(obj.evaluated_get(depsgraph))
+
+ vertices = []
+ for obj in objs_eval:
+ for v in obj.data.vertices:
+ vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
+
+ vertices = np.array(vertices)
+ min_x, min_y, min_z, _ = np.min(vertices, axis=0)
+ max_x, max_y, max_z, _ = np.max(vertices, axis=0)
+
+ # print(min_x, min_y, min_z)
+ # print(max_x, max_y, max_z)
+
+ scale_x = 1 / (max_x - min_x)
+ scale_y = 1 / (max_y - min_y)
+ scale_z = 1 / (max_z - min_z)
+ scale_min = min(scale_x, scale_y, scale_z)
+
+ assert scale_min < 1e6
+
+ translate_x = - (max_x + min_x) / 2 * scale_min
+ translate_y = - (max_y + min_y) / 2 * scale_min
+ translate_z = - min_z * scale_min
+
+ # form transformation matrix
+ trans = Matrix.Translation((translate_x, translate_y, translate_z))
+
+ scale = Matrix.Scale(scale_min, 4, (1, 0, 0)) @ Matrix.Scale(scale_min, 4, (0, 1, 0)) @ Matrix.Scale(scale_min, 4, (0, 0, 1))
+
+ # print(trans, scale)
+
+
+ root = get_root_obj(objs[0])
+ # print(root.name)
+ # print(root.scale)
+ # print(root.location)
+ # print(root.matrix_world)
+ # root.location = mathutils.Vector(root.location) + mathutils.Vector((translate_x, translate_y, translate_z))
+ # root.scale = mathutils.Vector(root.scale) * mathutils.Vector((scale_x, scale_y, scale_z))
+
+ # add the extra transformation to the root object's world matrix
+ root.matrix_world = trans @ scale @ root.matrix_world
+ # print(root.name)
+ # print(root.scale)
+ # print(root.location)
+ # print(root.matrix_world)
+
+ # refresh
+ bpy.context.view_layer.update()
+
+ ######### check if its successful
+ # objs_eval = []
+ # depsgraph = bpy.context.evaluated_depsgraph_get()
+ # for obj in objs:
+ # objs_eval.append(obj.evaluated_get(depsgraph))
+
+ # vertices = []
+ # for obj in objs_eval:
+ # for v in obj.data.vertices:
+ # vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
+
+ # vertices = np.array(vertices)
+ # min_x, min_y, min_z, _ = np.min(vertices, axis=0)
+ # max_x, max_y, max_z, _ = np.max(vertices, axis=0)
+
+ # print(min_x, min_y, min_z)
+ # print(max_x, max_y, max_z)
+
+def remesh(objs, target=5000):
+ num_v = {}
+ for obj in objs:
+ num_v[obj] = len(obj.data.vertices)
+
+    # sort objects by vertex count in descending order (list of (obj, count) tuples)
+ num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
+
+ # print(num_v_sort)
+ total_v = sum([num_v[obj] for obj in num_v])
+
+ iters = 0
+ while total_v > target and iters<20:
+ reduce = []
+ for obj, v in num_v_sort:
+ reduce.append(obj)
+ if sum([num_v[oo] for oo in reduce]) > 0.5 * total_v:
+ break
+ for obj in reduce:
+ # check if have shape key
+ if obj.data.shape_keys is not None:
+ # remove obj from num_v
+ num_v.pop(obj)
+ continue
+
+ ratio = 0.5
+ # apply decimate modifier
+ bpy.context.view_layer.objects.active = obj
+ bpy.ops.object.modifier_add(type='DECIMATE')
+ bpy.context.object.modifiers["Decimate"].ratio = ratio
+ bpy.ops.object.modifier_apply(modifier="Decimate")
+ # update num_v
+ num_v[obj] = len(obj.data.vertices)
+ total_v = sum([num_v[obj] for obj in num_v])
+ num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
+ # print(num_v_sort)
+ iters+=1
+
+
+def get_parents(obj):
+ if not obj.parent:
+ return [obj.name]
+ parents = get_parents(obj.parent)
+ parents.append(obj.name)
+ return parents
+
+def check(objs, arm):
+ # assert('Sketchfab_model' in bpy.data.objects)
+
+ # root_arm = get_root_obj(arm)
+ # for obj in objs:
+ # if root_arm != get_root_obj(obj):
+ # print('not same root')
+ # return -1
+ # return 1
+
+ # action_num = 0
+ # actions = bpy.data.actions
+ # for act in actions:
+ # action_num += 1
+ # fcurves = act.fcurves
+ # data_paths = []
+ # not_pose = False
+ # for fcurve in fcurves:
+ # data_paths.append(fcurve.data_path)
+ # if not fcurve.data_path.startswith('pose.bones'):
+ # # print(fcurve.data_path)
+ # not_pose = True
+ # # return -1
+ # if not_pose:
+ # print('zyhsb')
+ # print(data_paths)
+ # return -1
+ # return action_num
+
+ for obj in objs:
+ vertex_groups = obj.vertex_groups
+ # if not vertex_groups:
+ # continue
+ # for group in vertex_groups:
+ for vertex in obj.data.vertices:
+ vertex_info = {}
+ for group in vertex.groups:
+ name = vertex_groups[group.group].name
+ name = name.replace(" ", "_")
+ if True:
+ arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
+ if len(arm_modifier) != 1:
+ print('zyhsb', len(arm_modifier))
+ return -2
+ # name = arm_modifier[0].object.name + "--" + name
+ return 1
+
+ # for obj in objs:
+ # if obj.data.shape_keys is not None:
+ # return 1
+ # # only 942!!!
+ # return 0
+
+
+def delete(objs):
+ # check if the mesh object has skinning weight
+ for obj in objs:
+ vertex_groups = obj.vertex_groups
+ if not vertex_groups:
+ # delete the object
+ bpy.data.objects.remove(obj)
+ # print('delete!!!')
+ meshes = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "MESH":
+ meshes.append(obj)
+
+ return meshes
+
+
+def merge_mesh(folder_path, export = None, save_join = True):
+ # output_path = os.path.join(folder_path, "rigging_norig.json")
+ # if os.path.exists(output_path):
+ # print("Already processed folder:", folder_path)
+ # return
+ bpy.ops.wm.read_homefile(use_empty=True)
+ try:
+ reload_json(folder_path)
+ except:
+ print("Error in reloading json file")
+ # remove the folder
+ os.system(f"rm -r {folder_path}")
+ return None, None
+
+ bpy.ops.object.select_all(action='DESELECT')
+ if export:
+ bpy.ops.wm.save_as_mainfile(filepath='reload_' + export)
+
+ meshes = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "MESH":
+ bpy.context.view_layer.objects.active = obj
+ obj.select_set(True)
+ meshes.append(obj)
+ print("meshes length", len(meshes))
+
+ bpy.ops.object.join()
+ if export:
+ bpy.ops.wm.save_as_mainfile(filepath='join_' + export)
+
+ meshes = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "MESH":
+ meshes.append(obj)
+ if len(meshes) != 1:
+ bpy.ops.wm.save_as_mainfile(filepath='join_f.blend')
+ assert len(meshes) == 1
+ # remesh(meshes[0])
+
+
+ if save_join:
+ obj_path = os.path.join(folder_path, "object.obj")
+ bpy.ops.wm.obj_export(filepath=obj_path, export_materials=False, export_uv=False, export_triangulated_mesh=True)
+ # mesh = trimesh.load(glb_file_path)
+ # mesh.export(obj_path, file_type='obj')
+
+
+ # save to json file
+ total_armature_count = 0
+ armature_obj = []
+ mesh_obj = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "ARMATURE":
+ total_armature_count += 1
+ armature_obj.append(obj)
+ if obj.type == "MESH":
+ mesh_obj.append(obj)
+ if total_armature_count == 0:
+ print("No rigging information for the file:", folder_path+"\n")
+ return None, None
+
+
+ ######### delete bones that are not in the vertex group
+ vertex_group_name = [group.name for group in mesh_obj[0].vertex_groups]
+ bpy.context.view_layer.objects.active = armature_obj[0]
+ bpy.ops.object.mode_set(mode='EDIT')
+ edit_bones = armature_obj[0].data.edit_bones
+ bone_delete = set([bone.name for bone in edit_bones]) - set(vertex_group_name)
+ print(f"Deleting {len(bone_delete)} bones")
+ for bone in bone_delete:
+ # if the bone is root, then do not delete it
+ if edit_bones[bone].parent == None:
+ # return len([1 for child in edit_bones[bone].children if child.name in bone_delete])
+ num_children = len(edit_bones[bone].children)
+ if num_children <= 1:
+ edit_bones.remove(edit_bones[bone])
+ continue
+ if num_children > 1:
+ center = mathutils.Vector((0, 0, 0))
+ for child in edit_bones[bone].children:
+ center += child.head
+ center /= num_children
+ min_dist = 1e9
+ for child in edit_bones[bone].children:
+ dist = (child.head - center).length
+ if dist < min_dist:
+ min_dist = dist
+ min_child = child
+ for child in edit_bones[bone].children:
+ if child != min_child:
+ child.parent = min_child
+ edit_bones.remove(edit_bones[bone])
+ continue
+ continue
+ # assign bone's children to bone's parent
+ bone_obj = edit_bones[bone]
+ for child in bone_obj.children:
+ child.parent = bone_obj.parent
+
+ edit_bones.remove(edit_bones[bone])
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ if export:
+ bpy.ops.wm.save_as_mainfile(filepath='delete_' + export)
+
+ mesh_obj = []
+ armature_obj = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "MESH":
+ mesh_obj.append(obj)
+ if obj.type == "ARMATURE":
+ armature_obj.append(obj)
+ assert len(mesh_obj) == 1
+ assert len(armature_obj) == 1
+
+ return mesh_obj, armature_obj
+
+
+def process(file_path, obj_path=None, stamp=None, tex=False):
+ # check if obj_path exists
+ # if os.path.exists(obj_path + '/object.obj'):
+ # print('object.obj exists')
+ # return True
+ reset_scene()
+ load_object(file_path)
+ # bpy.ops.import_scene.gltf(filepath=glb_file_path)
+
+ # delete hierarchy collections['glTF_not_exported']
+ if 'glTF_not_exported' in bpy.data.collections:
+ print('DELETE glTF_not_exported')
+ bpy.data.collections.remove(bpy.data.collections['glTF_not_exported'])
+
+ if stamp is not None:
+ # Set the current frame to the stamp value
+ bpy.context.scene.frame_set(stamp)
+ print(f'Set the current frame to {stamp}')
+
+ # Ensure all objects are updated to this frame
+ bpy.context.view_layer.update()
+
+ mesh_obj = []
+ armature_obj = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "ARMATURE":
+ # if len(armature_obj) > 0:
+ # print(file_path, 'has more than 1 armature')
+ # return -2
+ armature_obj.append(obj)
+ # obj.show_in_front = True
+ armature_obj[-1].data.pose_position = 'POSE'
+ if obj.type == "MESH":
+ mesh_obj.append(obj)
+ # if obj.data.shape_keys is not None:
+ # return False
+
+ # mesh_obj = delete(mesh_obj)
+ # if len(mesh_obj) == 0:
+ # # print('zyhsb -1', file_path, obj_path)
+ # return -1
+ # return check(mesh_obj, armature_obj)
+
+
+ # total_vertices = np.array([len(obj.data.vertices) for obj in mesh_obj]).sum()
+ # if total_vertices < 1000: return
+ # if total_vertices > 10000: remesh(mesh_obj)
+
+
+ # bpy.ops.object.select_all(action='DESELECT')
+ # armature_obj.select_set(True)
+ # execute(bpy.context)
+
+
+ # normalize(mesh_obj)
+
+
+ mesh_obj = delete(mesh_obj)
+ if len(mesh_obj) == 0:
+ # print('zyhsb -1', file_path, obj_path)
+ return -1
+
+
+ save_json(obj_path, mesh_obj, armature_obj, arm_name=True)
+
+
+ if not tex:
+ save_mesh(obj_path + '/object.obj')
+ else:
+ save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
+
+
+ mesh_obj, armature_obj = merge_mesh(obj_path)
+ if mesh_obj is None or armature_obj is None:
+ # print('zyhsb -2', file_path, obj_path)
+ return -2
+
+
+ try:
+ normalize(mesh_obj)
+ except:
+ os.system(f"rm -r {obj_path}")
+ # print('zyhsb -3', file_path, obj_path)
+ return -3
+
+
+ save_json(obj_path, mesh_obj, armature_obj)
+
+ if not tex:
+ save_mesh(obj_path + '/object.obj')
+ else:
+ save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
+
+
+ return 1
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--object_path",
+ type=str,
+ required=True,
+ help="Path to the object file",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Path to the directory where the rendered images and metadata will be saved.",
+ )
+ parser.add_argument(
+ "--stamp",
+ type=int,
+ required=False,
+ help="Stamp to be used for the rendering.",
+ )
+    parser.add_argument(
+        "--tex",
+        action="store_true",
+        required=False,
+        help="Also export materials and textures.",
+    )
+ argv = sys.argv[sys.argv.index("--") + 1 :]
+ args = parser.parse_args(argv)
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ stamp = args.stamp if args.stamp else None
+ print(f'Stamp: {stamp}')
+ result = process(args.object_path, obj_path=args.output_dir, stamp=stamp, tex=args.tex)
+ # import numpy as np
+ # os.makedirs(args.output_dir, exist_ok=True) # the directory may be removed
+ # np.save(args.output_dir + '/result.npy', np.array(result))
\ No newline at end of file
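blender_script.py is meant to run inside Blender's bundled Python: everything before the standalone "--" belongs to Blender itself, and the script parses only the arguments after it. A hedged invocation sketch, assuming a blender binary on PATH and hypothetical input/output paths:

    import subprocess

    # Run Blender headless and forward the script's arguments after "--".
    subprocess.run([
        "blender", "--background", "--python", "Anymate/blender_script.py", "--",
        "--object_path", "assets/example.glb",   # hypothetical input model
        "--output_dir", "outputs/example",       # object.obj plus rigging/skinning JSON are written here
    ], check=True)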
diff --git a/Anymate/checkpoints/.gitkeep b/Anymate/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Anymate/configs/.gitkeep b/Anymate/configs/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Anymate/configs/conn.yaml b/Anymate/configs/conn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1abff94d426050be88f3053e00a592240359562
--- /dev/null
+++ b/Anymate/configs/conn.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: ce
+ mode: conn
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: attendjoints_con_combine
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
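Every file in Anymate/configs follows the same three-section layout (args, optimizer, model). A minimal sketch of reading one with PyYAML, assuming PyYAML is installed (the project's training entry point and its own config loader are not part of this diff):

    import yaml

    with open("Anymate/configs/conn.yaml") as f:
        cfg = yaml.safe_load(f)

    train_args, optim_args, model_args = cfg["args"], cfg["optimizer"], cfg["model"]
    print(train_args["mode"], model_args["encoder"], model_args["decoder"])
    # -> conn bert attendjoints_con_combine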
diff --git a/Anymate/configs/conn_token.yaml b/Anymate/configs/conn_token.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1abff94d426050be88f3053e00a592240359562
--- /dev/null
+++ b/Anymate/configs/conn_token.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: ce
+ mode: conn
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+  decoder: attendjoints_con_token
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/configs/diffusion.yaml b/Anymate/configs/diffusion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b8dae8c9f1a0237f7f215fffa34caf1357957b72
--- /dev/null
+++ b/Anymate/configs/diffusion.yaml
@@ -0,0 +1,49 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 4000
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: diffusion
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 50
+ num_train_step: 100
+ num_training_points: 128
+ seed: 42
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ encoder: transformer
+ decoder: Cross_Attention_Diffusion
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ input_channels: 3
+ output_channels: 3
+ num_z: 16
+ num_x: 128
+ z_dim: 768
+ x_dim: 512
+ num_blocks: 4
+ num_compute_layers: 4
+ num_heads: 8
+ mlp_ratio: 4.0
+ qkv_bias: true
+ drop: 0.0
+ attn_drop: 0.0
+ drop_path: 0.0
+ num_latents: 16
+ use_projection: true
\ No newline at end of file
diff --git a/Anymate/configs/diffusion_concat.yaml b/Anymate/configs/diffusion_concat.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8513edcdaf655e85ab8d822f85eeb285ee7986b8
--- /dev/null
+++ b/Anymate/configs/diffusion_concat.yaml
@@ -0,0 +1,46 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 4000
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: diffusion
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 1000
+ num_train_step: 100
+ num_training_points: 128
+ seed: 42
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ encoder: bert
+ decoder: Pointe_Diffusion
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ input_channels: 3
+ output_channels: 3
+ n_ctx: 128
+ width: 768
+ layers: 12
+ heads: 8
+ init_scale: 0.25
+ time_token_cond: true
+ cond_drop_prob: 0.1
+ use_projection: true
+
+
+
diff --git a/Anymate/configs/diffusion_cross.yaml b/Anymate/configs/diffusion_cross.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3cf718c365c496d1726b1279d3dfed27cc9c2deb
--- /dev/null
+++ b/Anymate/configs/diffusion_cross.yaml
@@ -0,0 +1,51 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 4000
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: diffusion
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 32
+ trainset: Anymate_train
+ test_freq: 1000
+ num_train_step: 100
+ num_training_points: 128
+ seed: 42
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ encoder: miche
+ decoder: Cross_Attention_Diffusion
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ input_channels: 3
+ output_channels: 3
+ num_z: 16
+ num_x: 128
+ z_dim: 768
+ x_dim: 512
+ num_blocks: 4
+ num_compute_layers: 4
+ num_heads: 8
+ mlp_ratio: 4.0
+ qkv_bias: true
+ drop: 0.0
+ attn_drop: 0.0
+ drop_path: 0.0
+ num_latents: 16
+ use_projection: true
+
+
diff --git a/Anymate/configs/joints.yaml b/Anymate/configs/joints.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..250725603c166362c8d946f4de949901f55511fb
--- /dev/null
+++ b/Anymate/configs/joints.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: joints
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: transformer_latent
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/configs/joints_implicit.yaml b/Anymate/configs/joints_implicit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7869584de281e527172ae90a7414707aa6beee6
--- /dev/null
+++ b/Anymate/configs/joints_implicit.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: joints
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 8
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: implicit_transformer
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/configs/joints_triplane.yaml b/Anymate/configs/joints_triplane.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cefc29b69a2d7d0cc2f61f8f469d212938c4486d
--- /dev/null
+++ b/Anymate/configs/joints_triplane.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: chamfer
+ mode: joints
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: triplane
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/configs/skin.yaml b/Anymate/configs/skin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..438408308f3ebb98ecc9800dbb3c48d0c1ec3399
--- /dev/null
+++ b/Anymate/configs/skin.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: cos_clamp
+ mode: skin
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 16
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: attendjoints_combine
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/configs/skin_multi.yaml b/Anymate/configs/skin_multi.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..27c1822dd1ec04201370ff7dd4684b9736371f5a
--- /dev/null
+++ b/Anymate/configs/skin_multi.yaml
@@ -0,0 +1,40 @@
+args:
+ aggr: max
+ checkpoint: Anymate/checkpoints
+ device: cuda
+ epochs: 200
+ finetune: true
+ gamma: 0.2
+ input_normal: false
+ logdir: Anymate/logs
+ loss: cos_clamp
+ mode: skin
+ resume: ''
+ root: Anymate/data
+ schedule: []
+ start_epoch: 0
+ test_batch: 1
+ testset: Anymate_test
+ train_batch: 4
+ trainset: Anymate_train
+ test_freq: 10
+
+optimizer:
+ weight_decay: 1.0e-05
+ lr: 0.0001
+
+model:
+ decoder: attendjoints_multi
+ encoder: bert
+ config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+ ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+ load_encoder: ''
+ num_joints: 96
+ out_channels: 3
+ width: 768
+ heads: 12
+ init_scale: 0.25
+ flash: False
+ use_checkpoint: False
+ qkv_bias: False
+ separate: False
\ No newline at end of file
diff --git a/Anymate/dataset.py b/Anymate/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..893071f9df342855a3661aa7e34f55fd7fac3195
--- /dev/null
+++ b/Anymate/dataset.py
@@ -0,0 +1,62 @@
+import torch
+from torch.utils.data import Dataset
+import os
+import numpy as np
+from Anymate.utils.dataset_utils import create_mask, index_to_sparse, index_to_sparse_con
+
+def my_collate(batch):
+ # print(len(batch))
+ data = {}
+ for key in batch[0]:
+        if key in ('vox', 'name', 'joints_num', 'skins_index', 'skins_weight', 'parent_index', 'conns', 'joints', 'bones', 'mesh_skins_index', 'mesh_skins_weight', 'mesh_pc', 'mesh_face'):
+ data[key] = [sample[key] for sample in batch]
+ elif key=='pc':
+ data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
+ elif key=='skins':
+ continue
+ elif key=='bones_num':
+ data[key] = torch.tensor([sample['bones_num'] for sample in batch])
+ else:
+ data[key] = torch.stack([sample[key] for sample in batch])
+
+ if 'skins_index' in batch[0]:
+ max_joints = max(data['joints_num'])
+ max_bones = max(data['bones_num'])
+ # max_joints = 64
+ skin_list = [index_to_sparse(data['skins_index'][i].unsqueeze(0), data['skins_weight'][i].unsqueeze(0), [1, 8192, max_bones])[0] for i in range(len(data['skins_index']))]
+ data['skins'] = torch.stack(skin_list,dim=0)
+ data['joints_mask'] = torch.stack([create_mask(sample['joints_num'],max_len=max_joints) for sample in batch])
+ data['bones_mask'] = torch.stack([create_mask(sample['bones_num'],max_len=max_bones) for sample in batch])
+
+ if 'conns' in batch[0]:
+ max_joints = max(data['joints_num'])
+ conn_matrix = torch.zeros(len(data['conns']), 96, max_joints)
+ for i in range(len(data['conns'])):
+ for j in range(data['joints_num'][i]):
+ conn_matrix[i, j, data['conns'][i][j].long()] = 1
+ data['conns'] = conn_matrix
+ if 'joints' in batch[0]:
+ padded_joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
+ for i in range(len(data['name'])):
+ padded_joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
+ data['joints'] = padded_joints_matrix
+ if 'bones' in batch[0]:
+ padded_bones_matrix = torch.ones(len(data['name']), 64, 6) * (-3)
+ for i in range(len(data['name'])):
+ padded_bones_matrix[i, :data['bones_num'][i], :] = data['bones'][i]
+ data['bones'] = padded_bones_matrix
+ return data
+
+class AnymateDataset(Dataset):
+ def __init__(self, name='Anymate_test', root='Anymate/data'):
+
+ if os.path.exists(os.path.join(root, name) + '.pt'):
+ self.data_list = torch.load(os.path.join(root, name) + '.pt')
+ else:
+ raise ValueError('Dataset not found at path: {}'.format(os.path.join(root, name) + '.pt'))
+
+ def __len__(self):
+ return len(self.data_list)
+
+ def __getitem__(self, idx):
+ return self.data_list[idx]
\ No newline at end of file
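A minimal sketch of pairing AnymateDataset with the custom collate function above, assuming Anymate_test.pt has been downloaded (see get_datasets.sh below) and that each sample carries the 'pc' tensor the collate function expects:

    from torch.utils.data import DataLoader
    from Anymate.dataset import AnymateDataset, my_collate

    dataset = AnymateDataset(name='Anymate_test', root='Anymate/data')
    loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=my_collate)

    batch = next(iter(loader))
    # 'pc' entries are stacked into 'points_cloud'; joints/bones are padded and masked per batch.
    print(batch['points_cloud'].shape, sorted(batch.keys()))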
diff --git a/Anymate/get_checkpoints.sh b/Anymate/get_checkpoints.sh
new file mode 100644
index 0000000000000000000000000000000000000000..280d8f2f852fd11511b255d1980bc28dd67c4b7b
--- /dev/null
+++ b/Anymate/get_checkpoints.sh
@@ -0,0 +1,22 @@
+cd Anymate/checkpoints
+mkdir -p joint
+cd joint
+
+echo "Downloading joint checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar?download=true" -O bert-transformer_latent-train-8gpu-finetune.pth.tar
+
+cd ..
+mkdir -p conn
+cd conn
+
+echo "Downloading conn checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar
+
+cd ..
+mkdir -p skin
+cd skin
+
+echo "Downloading skin checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_combine-train-8gpu-finetune.pth.tar
+
+echo "Finished downloading checkpoints!"
diff --git a/Anymate/get_datasets.sh b/Anymate/get_datasets.sh
new file mode 100644
index 0000000000000000000000000000000000000000..464fe4da7e8abf0fabf5c3ac22573015933b6289
--- /dev/null
+++ b/Anymate/get_datasets.sh
@@ -0,0 +1,12 @@
+mkdir -p Anymate/data && cd Anymate/data
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_test.pt?download=true" -O Anymate_test.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_0.pt?download=true" -O Anymate_train_0.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_1.pt?download=true" -O Anymate_train_1.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_2.pt?download=true" -O Anymate_train_2.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_3.pt?download=true" -O Anymate_train_3.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_4.pt?download=true" -O Anymate_train_4.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_5.pt?download=true" -O Anymate_train_5.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_6.pt?download=true" -O Anymate_train_6.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_7.pt?download=true" -O Anymate_train_7.pt
+
+echo "Finished downloading datasets!"
\ No newline at end of file
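The training set is shipped as eight shards, while AnymateDataset looks for a single <name>.pt file. A hedged merge sketch, assuming each shard stores a plain Python list of sample dicts (the same structure AnymateDataset indexes into):

    import torch

    samples = []
    for i in range(8):
        samples.extend(torch.load(f'Anymate/data/Anymate_train_{i}.pt'))

    # Write the combined list so AnymateDataset(name='Anymate_train') can find it.
    torch.save(samples, 'Anymate/data/Anymate_train.pt')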
diff --git a/Anymate/model.py b/Anymate/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..762c88d73c8feaee371d6afd57feb41aaf91dbe2
--- /dev/null
+++ b/Anymate/model.py
@@ -0,0 +1,360 @@
+import torch
+import torch.nn as nn
+from ThirdParty.michelangelo.utils.misc import get_config_from_file, instantiate_from_config
+# from ThirdParty.PointLLM.pointllm.model.pointllm import PointLLMLlamaForCausalLM
+from ThirdParty.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
+from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics
+from Anymate.utils.diffusion_encoder import TransformerEncoder
+from Anymate.models.joint import TransformerDecoder, ImplicitTransformerDecoder, TriPlaneDecoder
+from Anymate.models.conn import AttendjointsDecoder_con_combine, AttendjointsDecoder_con_token
+from Anymate.models.skin import AttendjointsDecoder_combine, AttendjointsDecoder_multi
+from Anymate.models.diffusion import Pointe_Diffusion, Cross_Attention_Diffusion
+
+class Encoder(nn.Module):
+ def __init__(self,
+ only_embed = True,
+ config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
+ ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
+ num_latents = 257,
+ device = 'cuda'):
+
+ super().__init__()
+
+ model_config = get_config_from_file(config_path)
+ if hasattr(model_config, "model"):
+ model_config = model_config.model
+
+ if ckpt_path is not None:
+ model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
+ else:
+ model = instantiate_from_config(model_config)
+ model.model.shape_model.encoder.num_latents = num_latents
+ model.model.shape_model.encoder.query = nn.Parameter(torch.randn((num_latents, 768), device=device, dtype=torch.float32) * 0.02)
+
+ self.shape_projection = model.model.shape_projection
+ self.encoder = model.model.shape_model.encoder
+ self.normal_embedder = components_from_spherical_harmonics
+ old_linear_proj = self.encoder.input_proj
+ self.encoder.input_proj = nn.Linear(old_linear_proj.in_features + 25, old_linear_proj.out_features)
+ self.encoder.input_proj.weight.data[:, :old_linear_proj.in_features] = old_linear_proj.weight.data[:, :old_linear_proj.in_features].clone()
+ self.encoder.input_proj.bias.data = old_linear_proj.bias.data.clone()
+ if not only_embed:
+ self.embed_dim = model.model.shape_model.embed_dim
+ self.pre_kl = model.model.shape_model.pre_kl
+ self.post_kl = model.model.shape_model.post_kl
+ self.transformer = model.model.shape_model.transformer
+
+
+ def encode_latents(self,
+ pc: torch.FloatTensor,
+ feats = None):
+
+ feats_embed = self.normal_embedder(feats)
+ feats = torch.cat([feats, feats_embed], dim=-1)
+
+ x, _ = self.encoder(pc, feats)
+
+ shape_embed = x[:, 0]
+ latents = x[:, 1:]
+
+ return shape_embed, latents
+
+
+ def encode_shape_embed(self, surface, return_latents: bool = False):
+ """
+
+ Args:
+ surface (torch.FloatTensor): [bs, n, 3 + c]
+ return_latents (bool):
+
+ Returns:
+ x (torch.FloatTensor): [bs, projection_dim]
+ shape_latents (torch.FloatTensor): [bs, m, d]
+ """
+
+ pc = surface[..., 0:3]
+ feats = surface[..., 3:]
+
+ shape_embed, shape_latents = self.encode_latents(pc, feats)
+ x = shape_embed @ self.shape_projection
+
+ if return_latents:
+ return x, shape_latents
+ else:
+ return x
+
+
+ def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
+ posterior = None
+ if self.embed_dim > 0:
+ moments = self.pre_kl(latents)
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
+
+ if sample_posterior:
+ kl_embed = posterior.sample()
+ else:
+ kl_embed = posterior.mode()
+ else:
+ kl_embed = latents
+
+ return kl_embed, posterior
+
+
+ def decode(self, latents: torch.FloatTensor):
+ latents = self.post_kl(latents)
+ return self.transformer(latents)
+
+
+class EncoderDecoder(nn.Module):
+ def __init__(self,
+ decoder = 'mlp',
+ encoder = 'miche',
+ config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
+ ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
+ load_encoder = '',
+ num_joints = 96,
+ out_channels = 3,
+ width = 768,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False,
+ separate = False,
+ **kwargs):
+
+ super().__init__()
+ self.decoder_name = decoder
+ self.encoder_name = encoder
+ self.dtype = dtype
+ self.load_encoder = load_encoder
+
+ if decoder == 'transformer_latent':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = TransformerDecoder(
+ num_latents = num_joints,
+ out_channels = out_channels,
+ width = width,
+ device = device,
+ dtype = dtype,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ qkv_bias = qkv_bias
+ )
+ elif decoder == 'implicit_transformer':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = ImplicitTransformerDecoder(
+ device = device,
+ dtype = dtype,
+ num_latents = 257,
+ out_channels = 1,
+ width = width,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ qkv_bias = qkv_bias
+ )
+ elif decoder == 'triplane': #consider add these parameters to config
+ self.only_embed = True
+ self.return_latents = False
+ self.decoder = TriPlaneDecoder(
+ z_dim = 768,
+ c_dim = 0,
+ w_dim = 768,
+ mapping_kwargs = {'num_layers': 2},
+ synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}
+ )
+
+ elif decoder == 'Pointe_Diffusion':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = Pointe_Diffusion(**kwargs)
+
+ elif decoder == 'Cross_Attention_Diffusion':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = Cross_Attention_Diffusion(**kwargs)
+
+ elif decoder == 'attendjoints_combine':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = AttendjointsDecoder_combine(
+ width = width,
+ device = device,
+ dtype = dtype,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ separate = separate,
+ qkv_bias = qkv_bias
+ )
+ elif decoder == 'attendjoints_multi':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = AttendjointsDecoder_multi(
+ width = width,
+ device = device,
+ dtype = dtype,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ qkv_bias = qkv_bias,
+ separate=separate
+ )
+ elif decoder == 'attendjoints_con_combine':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = AttendjointsDecoder_con_combine(
+ width = width,
+ device = device,
+ dtype = dtype,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ qkv_bias = qkv_bias
+ )
+ elif decoder == 'attendjoints_con_token':
+ self.only_embed = False
+ self.return_latents = True
+ self.decoder = AttendjointsDecoder_con_token(
+ width = width,
+ device = device,
+ dtype = dtype,
+ heads = heads,
+ init_scale = init_scale,
+ flash = flash,
+ use_checkpoint = use_checkpoint,
+ qkv_bias = qkv_bias,
+ separate = separate
+ )
+
+ if encoder == 'miche':
+ if not self.load_encoder:
+ self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=ckpt_path, device=device)
+ else:
+ self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=None, device=device)
+ try:
+ print("=> loading encoder checkpoint '{}'".format(self.load_encoder))
+ checkpoint = torch.load(self.load_encoder, map_location='cpu')
+ state_dict = {k[8:]: v for k, v in checkpoint['state_dict'].items() if k.startswith('encoder')}
+ self.encoder.load_state_dict(state_dict)
+ print("=> loaded encoder checkpoint '{}'".format(self.load_encoder))
+ except:
+ print("=> no encoder checkpoint found at '{}'".format(self.load_encoder))
+ if self.load_encoder:
+ self.point_proj = nn.Sequential(
+ nn.Linear(768, 768, dtype=dtype),
+ nn.GELU(),
+ nn.Linear(768, 768, dtype=dtype),
+ )
+
+ if encoder == 'bert':
+ # model_name = 'RunsenXu/PointLLM_7B_v1.2'
+ # model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True, torch_dtype=dtype)
+ # self.encoder = model.model.point_backbone.to(device)
+ # model = None
+ from ThirdParty.PointLLM.pointllm.model import PointTransformer
+ from ThirdParty.PointLLM.pointllm.utils import cfg_from_yaml_file
+ import os
+            # path to the PointBERT config file shipped with PointLLM
+ point_bert_config_name = "PointTransformer_8192point_2layer" # * default for v1.2, v1.1 uses PointTransformer_base_8192point.yaml
+ point_bert_config_addr = os.path.join("./ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml")
+ print(f"Loading PointBERT config from {point_bert_config_addr}.")
+ point_bert_config = cfg_from_yaml_file(point_bert_config_addr)
+ point_bert_config.model.point_dims = 6
+ use_max_pool = getattr(point_bert_config.model, "use_max_pool", False) # * default is false
+
+ self.encoder = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool).to(device)
+ if self.return_latents:
+ self.point_proj = nn.Sequential(
+ nn.Linear(384, 512, dtype=dtype),
+ nn.GELU(),
+ nn.Linear(512, 512, dtype=dtype),
+ nn.GELU(),
+ nn.Linear(512, 768, dtype=dtype)
+ )
+ else:
+ self.point_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Linear(384, 512, dtype=dtype),
+ nn.GELU(),
+ nn.Linear(512, 512, dtype=dtype),
+ nn.GELU(),
+ nn.Linear(512, 768, dtype=dtype)
+ ),
+ nn.Linear(513, 1, dtype=dtype)
+ ])
+ if encoder == 'transformer':
+ self.points_cloud_embed = nn.Linear(
+ 768, 768, device=device, dtype=dtype
+ )
+ self.encoder = TransformerEncoder(device=device,dtype=dtype, num_latents=kwargs['num_latents'])
+
+
+
+ def encode(self, data, device='cuda'):
+ assert self.encoder_name in ['miche', 'bert', 'transformer'], f'Encoder {self.encoder_name} not supported'
+ if self.encoder_name == 'miche':
+ surface = data['points_cloud'].to(self.dtype).to(device)
+
+ # encoding
+ shape_embed, shape_latents = self.encoder.encode_shape_embed(surface, return_latents=True) # ShapeAsLatentPerceiver.encode_latents(): encoder
+
+ if self.only_embed:
+ if self.return_latents:
+ if self.load_encoder:
+ return self.point_proj(torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1))
+ return torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1) # torch.Size([bs, 257, 768]
+ return shape_embed # shape_embed: torch.Size([bs, 768])
+
+ shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents) # ShapeAsLatentPerceiver.encode_kl_embed(): pre_kl + DiagonalGaussianDistribution()
+ # shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents, sample_posterior=False) # not sample
+ # pretrained weight has 0 +- 0.7 mean and 0.5 +- 0.5 std
+ # trained weight has 0 +- 1.8 mean and 0.1 +- 0.1 std
+ # generally okay
+
+ latents = self.encoder.decode(shape_zq) # ShapeAsLatentPerceiver.decode(): post_kl + transformer
+
+ if not self.return_latents:
+ latents = torch.cat([shape_latents, latents], dim=1) # torch.Size([bs, 512, 768])
+
+ if self.load_encoder:
+ return self.point_proj(torch.cat([shape_embed.unsqueeze(1), latents], dim=1))
+ return torch.cat([shape_embed.unsqueeze(1), latents], dim=1) # torch.Size([bs, 257 / 513, 768])
+
+ if self.encoder_name == 'bert':
+ points = data['points_cloud'].to(self.dtype).to(device)
+ points = points[:, :, :3] / 2
+ points = torch.cat([points, torch.zeros_like(points)], dim=-1)
+ points = self.encoder(points)
+
+ if self.return_latents:
+ points = self.point_proj(points)
+ else:
+ points = self.point_proj[0](points)
+ points = self.point_proj[1](points.permute(0, 2, 1)).squeeze(-1)
+ return points
+
+ if self.encoder_name == 'transformer':
+ points = data['points_cloud'].to(self.dtype).to(device)
+ cond = self.encoder.encode_pc(points)
+ cond = self.points_cloud_embed(cond)
+ return cond
+
+ def forward(self, data, device='cuda', downsample=False, **kwargs):
+ latents = self.encode(data, device)
+ # print('latents shape', latents.shape)
+
+ logits = self.decoder(latents, data, device=device, downsample=downsample,**kwargs)
+
+ return logits
\ No newline at end of file
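A minimal sketch of building the joint-prediction model from one of the YAML configs above, assuming the model section maps directly onto EncoderDecoder's keyword arguments (as its signature suggests) and that the Michelangelo/PointLLM third-party files referenced by config_path/ckpt_path are in place:

    import torch
    import yaml
    from Anymate.model import EncoderDecoder
    from Anymate.args import anymate_args

    with open('Anymate/configs/joints.yaml') as f:
        cfg = yaml.safe_load(f)

    model = EncoderDecoder(**cfg['model'], device=anymate_args.device).to(anymate_args.device)

    # The checkpoint layout is an assumption; many torch.save archives wrap weights under 'state_dict'.
    state = torch.load(anymate_args.checkpoint_joint, map_location='cpu')
    model.load_state_dict(state['state_dict'] if isinstance(state, dict) and 'state_dict' in state else state, strict=False)
    model.eval()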
diff --git a/Anymate/models/__init__.py b/Anymate/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Anymate/models/conn.py b/Anymate/models/conn.py
new file mode 100644
index 0000000000000000000000000000000000000000..433b85f258e6409ecb88144017f290220cb4d6f0
--- /dev/null
+++ b/Anymate/models/conn.py
@@ -0,0 +1,195 @@
+import torch
+import torch.nn as nn
+from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, ResidualAttentionBlock, Transformer
+from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder, components_from_spherical_harmonics
+
+class AttendjointsDecoder_con_combine(nn.Module):
+ def __init__(self,
+ width = 768,
+ layers = 2,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ separate = False,
+ use_mask = True):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.separate = separate
+ self.use_mask = use_mask
+ # self.num_latents = num_latents
+
+ # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
+
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
+
+ self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ ) for _ in range(layers)])
+
+ self.self_attn = nn.ModuleList([ResidualAttentionBlock(
+ device=device,
+ dtype=dtype,
+ n_ctx=-1,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ ) for _ in range(layers * 2)])
+
+ # self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
+
+
+ self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+ # self.last_cross_attn = ResidualCrossAttentionBlock(
+ # device=device,
+ # dtype=dtype,
+ # width=width,
+ # heads=heads,
+ # init_scale=init_scale,
+ # qkv_bias=qkv_bias,
+ # flash=flash,
+ # )
+ # self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+ # self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
+
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
+
+ joints = data['joints'].to(device)
+ max_joints = max(data['joints_num'])
+ joints = joints[:, :max_joints, :3]
+
+ joints_embeds = self.fourier_embedder(joints)
+ joints_embeds = self.co_proj(joints_embeds)
+
+ joints_num = joints_embeds.shape[-2]
+
+ x = [joints_embeds, joints_embeds.clone()]
+
+ for i in range(2):
+ for j, layer in enumerate(self.cross_attn):
+
+ x[i] = layer(x[i], latents)
+
+ if self.use_mask:
+ x[i] = self.self_attn[2*i+j](x[i], mask=data['joints_mask'].to(device))
+ else:
+ x[i] = self.self_attn[2*i+j](x[i])
+
+ # Dot Product between points and joints
+ logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(x[0])), self.q_proj(self.ln_2(x[1]))) # (b, n, m)
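+ # The two attended copies of the joint embeddings act as the two end-points of a candidate
+ # connection; their dot products give a dense n x m joint-to-joint connectivity score matrix.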
+
+ if self.use_mask:
+ mask = data['joints_mask'].to(device)
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
+
+ return logits
+
+class AttendjointsDecoder_con_token(nn.Module):
+ def __init__(self,
+ width = 768,
+ layers = 4,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ head_token_length =128,
+ separate = False,
+ use_mask = True):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.use_mask = use_mask
+ self.layer_norm = nn.LayerNorm(width)
+ self.head_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
+ self.tail_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
+ self.head_mlp = nn.ModuleList([
+ nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
+ nn.Linear(512, 512, device=device, dtype=dtype),
+ nn.Linear(512, width, device=device, dtype=dtype),
+ nn.LayerNorm(width)
+
+ ])
+ self.tail_mlp = nn.ModuleList([
+ nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
+ nn.Linear(512, 512, device=device, dtype=dtype),
+ nn.Linear(512, width, device=device, dtype=dtype),
+ nn.LayerNorm(width)
+ ])
+
+ self.self_attn = Transformer(
+ device=device,
+ dtype=dtype,
+ n_ctx=-1,
+ width=width,
+ layers=layers,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=False,
+ )
+ self.separate = separate
+ self.normal_embedder = components_from_spherical_harmonics
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.joints_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+ self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
+
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
+ joints = data['joints'].to(device)
+ max_joints = max(data['joints_num'])
+ joints = joints[:, :max_joints, :3]
+ joints_embeds_fourier = self.fourier_embedder(joints)
+ joints_embeds = self.joints_proj(joints_embeds_fourier)
+ # Concatenate embeddings
+ x = torch.cat([joints_embeds, latents], dim=-2) # (b, max_joint+token_num, c)
+ # Pass through self-attention
+ if self.use_mask:
+ mask = data['mask'].to(device)
+ append_size = x.shape[1]-mask.shape[1] # the zero needs to append after mask
+ batch_size = mask.shape[0]
+
+ mask_extend = torch.ones((batch_size,append_size)).to(device)
+ mask = torch.cat([mask,mask_extend],dim=-1).to(device)
+
+ x = self.self_attn(x,mask)
+ else:
+ x = self.self_attn(x)
+ joints, _= x.split([joints_embeds.shape[1], latents.shape[1]], dim=1)
+ joints = self.output_proj_joints(self.layer_norm(joints))
+ joints_head = torch.concat([joints, self.head_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
+ joints_tail = torch.concat([joints, self.tail_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
+ for layer in self.head_mlp:
+ joints_head = layer(joints_head)
+ for layer in self.tail_mlp:
+ joints_tail = layer(joints_tail)
+ logits = torch.einsum('bik,bjk->bij', joints_head, joints_tail)
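+ # logits[b, i, j]: compatibility score between the head embedding of joint i and the tail
+ # embedding of joint j, i.e. a dense joint-to-joint connectivity score matrix.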
+
+ return logits
\ No newline at end of file
diff --git a/Anymate/models/diffusion.py b/Anymate/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..84fb6c021d1f0e103e09ba5dc608ce193b8431aa
--- /dev/null
+++ b/Anymate/models/diffusion.py
@@ -0,0 +1,483 @@
+F"""
+Adapted from: https://github.com/openai/openai/blob/55363aa496049423c37124b440e9e30366db3ed6/orc/orc/diffusion/vit.py
+"""
+
+import math
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+
+from einops import repeat
+from Anymate.utils.diffusion_utils import *
+from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock  # the Transformer used below is the local definition in this file
+
+from diffusers import DDPMScheduler, DDIMScheduler
+from sklearn.cluster import DBSCAN
+
+def init_linear(l, stddev):
+ nn.init.normal_(l.weight, std=stddev)
+ if l.bias is not None:
+ nn.init.constant_(l.bias, 0.0)
+
+class projection_transformer(nn.Module):
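+ # Pools a variable number of encoder latents into a fixed set of `num_latents` learnable
+ # query tokens via a single cross-attention block, producing a compact conditioning signal.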
+ def __init__(self, num_latents=16, width = 16, heads=8, dtype = torch.float32):
+ super().__init__()
+ self.num_latents = num_latents
+ self.query = nn.Parameter(torch.randn((num_latents, width), dtype=dtype) * 0.02)
+
+ self.cross_attn = ResidualCrossAttentionBlock(
+ device= 'cuda',
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=0.25,
+ qkv_bias=True,
+ flash=False,
+ )
+ self.output_proj = nn.Linear(width, width,dtype=dtype)
+
+ def forward(self, latents):
+ bs = latents.shape[0]
+ query = repeat(self.query, "m c -> b m c", b=bs)
+ embed = self.cross_attn(query, latents)
+ logits = self.output_proj(embed)
+
+ return logits
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ init_scale: float,
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.heads = heads
+ self.c_qkv = nn.Linear(width, width * 3, dtype=dtype)
+ self.c_proj = nn.Linear(width, width, dtype=dtype)
+ self.attention = QKVMultiheadAttention(dtype=dtype, heads=heads, n_ctx=n_ctx)
+ init_linear(self.c_qkv, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ x = self.c_qkv(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x
+
+class MLP(nn.Module):
+ def __init__(self, *, dtype: torch.dtype, width: int, init_scale: float):
+ super().__init__()
+ self.width = width
+ self.c_fc = nn.Linear(width, width * 4, dtype=dtype)
+ self.c_proj = nn.Linear(width * 4, width, dtype=dtype)
+ self.gelu = nn.GELU()
+ init_linear(self.c_fc, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ return self.c_proj(self.gelu(self.c_fc(x)))
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(self, *, dtype: torch.dtype, heads: int, n_ctx: int):
+ super().__init__()
+ self.dtype = dtype
+ self.heads = heads
+ self.n_ctx = n_ctx
+
+ def forward(self, qkv):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.heads // 3
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
+ weight = torch.einsum(
+ "bthc,bshc->bhts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ wdtype = weight.dtype
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ init_scale: float = 1.0,
+ ):
+ super().__init__()
+
+ self.attn = MultiheadAttention(
+ dtype=dtype,
+ n_ctx=n_ctx,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ )
+ self.ln_1 = nn.LayerNorm(width, dtype=dtype)
+ self.mlp = MLP(dtype=dtype, width=width, init_scale=init_scale)
+ self.ln_2 = nn.LayerNorm(width, dtype=dtype)
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attn(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ layers: int,
+ heads: int,
+ init_scale: float = 0.25,
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.layers = layers
+ init_scale = init_scale * math.sqrt(1.0 / width)
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(
+ dtype=dtype,
+ n_ctx=n_ctx,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ for block in self.resblocks:
+ x = block(x)
+ return x
+
+class PointDiffusionTransformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ dtype: torch.dtype,
+ input_channels: int = 3,
+ output_channels: int = 3,
+ n_ctx: int = 1024,
+ width: int = 768,
+ layers: int = 12,
+ heads: int = 8,
+ init_scale: float = 0.25,
+ time_token_cond: bool = True,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.output_channels = output_channels
+ self.n_ctx = n_ctx
+ self.time_token_cond = time_token_cond
+ self.time_embed = MLP(
+ dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
+ )
+ self.ln_pre = nn.LayerNorm(width, dtype=dtype)
+ self.backbone = Transformer(
+ dtype=dtype,
+ n_ctx=n_ctx + int(time_token_cond),
+ width=width,
+ layers=layers,
+ heads=heads,
+ init_scale=init_scale,
+ )
+ self.ln_post = nn.LayerNorm(width,dtype=dtype)
+ self.input_proj = nn.Linear(input_channels, width, dtype=dtype)
+ self.output_proj = nn.Linear(width, output_channels,dtype=dtype)
+ with torch.no_grad():
+ self.output_proj.weight.zero_()
+ self.output_proj.bias.zero_()
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
+ """
+ :param x: an [N x C x T] tensor.
+ :param t: an [N] tensor.
+ :return: an [N x C' x T] tensor.
+ """
+ assert x.shape[-1] == self.n_ctx
+ t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+ return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
+
+ def _forward_with_cond(
+ self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
+ ) -> torch.Tensor:
+ h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
+ for emb, as_token in cond_as_token:
+ if not as_token:
+ h = h + emb[:, None]
+ extra_tokens = [
+ (emb[:, None] if len(emb.shape) == 2 else emb)
+ for emb, as_token in cond_as_token
+ if as_token
+ ]
+ if len(extra_tokens):
+ h = torch.cat(extra_tokens + [h], dim=1)
+
+ h = self.ln_pre(h)
+ h = self.backbone(h)
+ h = self.ln_post(h)
+ if len(extra_tokens):
+ h = h[:, sum(h.shape[1] for h in extra_tokens) :]
+ h = self.output_proj(h)
+ return h.permute(0, 2, 1)
+
+class Pointe_Diffusion(PointDiffusionTransformer):
+ '''
+ Point-E style diffusion transformer over joint positions.
+ forward_func expects in kwargs:
+ noisy_joints: [N x 3 x n_ctx] noisy joint coordinates
+ timesteps: [N] diffusion timesteps
+ init:
+ n_ctx: int = 1024: number of diffused points (context length)
+ '''
+ def __init__(
+ self,
+ *,
+ device = 'cuda',
+ dtype = torch.float32,
+ encoder = 'miche',
+ n_ctx: int = 1024,
+ token_cond: bool = True,
+ cond_drop_prob: float = 0.1,
+ fix_emb: bool = False,
+
+ **kwargs,
+ ):
+ super().__init__(dtype=dtype, n_ctx=n_ctx + int(token_cond), **kwargs)
+ self.n_ctx = n_ctx
+ self.token_cond = token_cond
+ # self.proj_transformer = projection_transformer(**kwargs)
+ self.encoder_name = encoder
+ self.cond_drop_prob = cond_drop_prob
+ self.fix_emb = fix_emb
+ self.dtype = dtype
+ self.inference = False
+
+ def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ # NOTE: vestigial helper kept from the original point-e code; `self.clip` is not defined in this class.
+ with torch.no_grad():
+ return dict(embeddings=self.clip(batch_size, **model_kwargs))
+
+ def inference_mode(self, eps=0.03):
+ # NOTE: `eps` is currently unused; the DBSCAN call in forward() hard-codes eps=0.05.
+ self.inference = True
+
+ def forward_func(
+ self,
+ latent: torch.Tensor,
+ data,
+ device='cuda',
+ downsample = False,
+ **kwargs,
+ ):
+ t = kwargs['timesteps'].to(latent.device)
+ x = kwargs['noisy_joints'].to(latent.device)
+ assert x.shape[-1] == self.n_ctx, f"x shape: {x.shape}, n_ctx: {self.n_ctx}"
+ t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+
+ if self.training:
+ mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
+ latent = latent * mask[:,None,None].to(latent.device)
+
+ latent = [(latent, self.token_cond), (t_embed, self.time_token_cond)]
+ return self._forward_with_cond(x, latent)
+
+ def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
+ if self.inference == False:
+ return self.forward_func(latent, data, device, downsample, **kwargs)
+ else:
+ generator=torch.Generator(device='cpu')
+ scheduler = DDIMScheduler(100)
+ scheduler.set_timesteps(100)
+ points_shape = [1, self.n_ctx, 3]
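+ # DDIM sampling at inference: start from Gaussian noise over n_ctx candidate points,
+ # iteratively denoise them, then merge nearby predictions into joint centres with DBSCAN.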
+
+ points_noise = randn_tensor(points_shape, generator=generator)
+ points = points_noise.permute(0, 2, 1).to(latent.device)
+ for t in scheduler.timesteps:
+ with torch.no_grad():
+ time_steps = torch.ones(1, 1, dtype=torch.long) * t
+ model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
+
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
+ points = points.permute(0, 2, 1).cpu()
+ assert points.shape[0] == 1, "Inference mode only supports batch size 1"
+ joints = points[0].detach().cpu().numpy()
+ clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
+ cluster_centers = []
+ for cluster in set(clustering.labels_):
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
+ return cluster_centers
+
+class Cross_Attention_Diffusion(nn.Module):
+ def __init__(self,
+ input_channels=3, output_channels=3,
+ num_z=16, num_x=1024, z_dim=768, x_dim=512,
+ num_blocks=6, num_compute_layers=4, num_heads=8,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,num_latents=16,
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
+ use_projection = True,):
+ super().__init__()
+ self.use_projection = use_projection
+ self.device = device
+ self.num_z = num_z
+ self.num_x = num_x
+ self.z_dim = z_dim
+ if use_projection:
+ self.proj_transformer = projection_transformer(num_latents=num_latents, width=z_dim, heads=num_heads)
+ self.prev_latent = nn.Parameter(torch.zeros(1, self.num_z + num_latents + 1, z_dim))
+ self.inference = False
+
+ self.input_proj = nn.Linear(input_channels, x_dim)
+ self.ln_pre = nn.LayerNorm(x_dim)
+ self.z_init = nn.Parameter(torch.zeros(1, num_z, z_dim))
+
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
+ self.time_embed = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim)
+
+ self.latent_mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ln_latent = nn.LayerNorm(z_dim)
+ self.blocks = nn.ModuleList([
+ RCW_Block(z_dim, x_dim, num_compute_layers=num_compute_layers,
+ num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path,
+ act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(num_blocks)
+ ])
+
+ # output blocks
+ self.ln_post = nn.LayerNorm(x_dim)
+ self.output_proj = nn.Linear(x_dim, output_channels)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ nn.init.normal_(self.z_init, std=.02)
+
+ # initialize nn.Linear and nn.LayerNorm
+ self.apply(self._init_weights)
+
+ nn.init.constant_(self.ln_latent.weight, 0)
+ nn.init.constant_(self.ln_latent.bias, 0)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def inference_mode(self,eps=0.03):
+ self.inference = True
+
+ def forward_func(self, latent, data, device='cuda', downsample = False, **kwargs):
+ """
+ Forward pass of the model.
+
+ Parameters:
+ x: [B, num_x, C_in]
+ t: [B]
+ cond: [B, num_cond, C_latent]
+ prev_latent: [B, num_z + num_cond + 1, C_latent]
+
+ Returns:
+ x_denoised: [B, num_x, C_out]
+ z: [B, num_z + num_cond + 1, C_latent]
+ """
+ t = kwargs['timesteps'].to(latent.device)
+ x = kwargs['noisy_joints'].to(latent.device)
+ x = x.permute(0, 2, 1)
+ B, num_x, _ = x.shape
+ if self.use_projection:
+ latent = self.proj_transformer(latent)
+ assert num_x == self.num_x, f"x shape: {x.shape}, num_x: {self.num_x}"
+ # if prev_latent is not None:
+ # _, num_z, _ = prev_latent.shape
+ # assert num_z == self.num_z + num_cond + 1
+ # else:
+ # prev_latent = torch.zeros(B, self.num_z + num_cond + 1, self.z_dim).to(x.device)
+
+ # timestep embedding, [B, 1, z_dim]
+ t_embed = self.time_embed(timestep_embedding(t, self.z_dim))
+ if t_embed.dim() == 2:
+ t_embed = t_embed.unsqueeze(1)
+
+ # project x -> [B, num_x, C_x]
+ x = self.input_proj(x)
+ x = self.ln_pre(x)
+
+ # latent self-conditioning
+ z = self.z_init.repeat(B, 1, 1) # [B, num_z, z_dim]
+ z = torch.cat([z, latent, t_embed], dim=1) # [B, num_z + num_cond + 1, z_dim]
+ prev_latent = self.prev_latent + self.latent_mlp(self.prev_latent.detach())
+ z = z + (self.ln_latent(prev_latent))
+
+ # compute
+ for blk in self.blocks:
+ z, x = blk(z, x)
+
+ # output proj
+ x = self.ln_post(x)
+ x_denoised = self.output_proj(x)
+ return x_denoised.permute(0, 2, 1)
+
+ def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
+ if self.inference == False:
+ return self.forward_func(latent, data, device, downsample, **kwargs)
+ else:
+ generator=torch.Generator(device='cpu')
+ scheduler = DDIMScheduler(100)
+ scheduler.set_timesteps(100)
+ points_shape = [1, self.num_x, 3]
+
+ points_noise = randn_tensor(points_shape, generator=generator)
+ points = points_noise.permute(0, 2, 1).to(latent.device)
+ for t in scheduler.timesteps:
+ with torch.no_grad():
+ time_steps = torch.ones(1, 1, dtype=torch.long) * t
+ time_steps = time_steps.to(latent.device)
+ model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
+
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
+ points = points.permute(0, 2, 1).cpu()
+ assert points.shape[0] == 1, "Inference mode only supports batch size 1"
+ joints = points[0].detach().cpu().numpy()
+ clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
+ cluster_centers = []
+ for cluster in set(clustering.labels_):
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
+ return cluster_centers
+
\ No newline at end of file
diff --git a/Anymate/models/joint.py b/Anymate/models/joint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e78adf8a2c844c0c3c1d3584589abd9acf78451c
--- /dev/null
+++ b/Anymate/models/joint.py
@@ -0,0 +1,282 @@
+import torch
+import torch.nn as nn
+from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder
+from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock
+from ThirdParty.eg3d.training.networks_stylegan2 import Generator as StyleGAN2Backbone
+from ThirdParty.eg3d.training.networks_stylegan2 import FullyConnectedLayer
+from Anymate.utils.vol_utils import get_co, sample_from_planes, generate_planes
+from einops import repeat
+from sklearn.cluster import DBSCAN
+from Anymate.utils.vol_utils import extract_keypoints
+
+class TransformerDecoder(nn.Module):
+ def __init__(self,
+ num_latents = 96,
+ num_kv_latents = 257,
+ out_channels = 3,
+ width = 768,
+ layers = 7,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.num_latents = num_latents
+ self.inference = False
+ self.eps = 0.03
+
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
+
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ n_data=num_kv_latents,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash
+ )
+
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
+
+ def inference_mode(self, eps=0.03, min_samples=1):
+ self.inference = True
+ self.eps = eps
+ self.min_samples = min_samples
+
+ def forward(self, latents, data=None, device='cuda', downsample=False, dtype=torch.float32):
+
+ bs = latents.shape[0]
+ query = repeat(self.query, "m c -> b m c", b=bs)
+ logits = self.cross_attn_decoder(query, latents)
+ logits = self.ln_post(logits)
+ logits = self.output_proj(logits)
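+ # logits: (bs, num_latents, out_channels); with out_channels=3 these are candidate 3D joint
+ # positions, one per learnable query, clustered by DBSCAN in inference mode below.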
+ if self.inference:
+ assert logits.shape[0] == 1, "Inference mode only supports batch size 1"
+ joints = logits[0].detach().cpu().numpy()
+ clustering = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit(joints)
+ cluster_centers = []
+ for cluster in set(clustering.labels_):
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
+ return cluster_centers
+ return logits
+
+
+class ImplicitTransformerDecoder(nn.Module):
+
+ def __init__(self, *,
+ device = 'cuda',
+ dtype = torch.float32,
+ num_latents = 257,
+ out_channels = 1,
+ width = 768,
+ heads = 12,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ init_scale: float = 0.25,
+ qkv_bias: bool = False,
+ flash: bool = False,
+ use_checkpoint: bool = False):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.inference = False
+
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ n_data=num_latents,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash
+ )
+
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
+
+ # self.queries = get_vol().to(device)
+
+ def inference_mode(self):
+ self.inference = True
+
+ def forward(self, latents: torch.FloatTensor, data=None, device='cuda', downsample=False):
+ bs = latents.shape[0]
+ # queries = repeat(self.queries, "m c -> b m c", b=bs)
+ out = []
+ for b in range(bs):
+ queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
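+ # get_co is assumed to return one 3D coordinate per occupied voxel of data['vox'][b], so the
+ # decoder predicts one scalar score per occupied voxel, later converted by extract_keypoints.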
+ if downsample and data['vox'][b].shape[0] > 50000:
+ # random sample
+ idx = torch.randperm(data['vox'][b].shape[0])[:50000]
+ queries = queries[:, idx]
+ queries = self.query_proj(self.fourier_embedder(queries))
+ x = self.cross_attn_decoder(queries, latents[b:b+1])
+ x = self.ln_post(x)
+ x = self.output_proj(x)
+ if downsample and data['vox'][b].shape[0] > 50000:
+ out.append((x.squeeze(0), idx))
+ else:
+ out.append(x.squeeze(0))
+ if self.inference:
+ assert len(out) == 1, "Inference mode only supports batch size 1"
+ return extract_keypoints(out[0], data['vox'][0])
+
+ return out
+
+
+class TriPlaneDecoder(torch.nn.Module):
+ def __init__(self,
+ z_dim = 768, # Input latent (Z) dimensionality.
+ c_dim = 0, # Conditioning label (C) dimensionality.
+ w_dim = 768, # Intermediate latent (W) dimensionality.
+ # img_resolution, # Output resolution.
+ # img_channels, # Number of output color channels.
+ # sr_num_fp16_res = 0,
+ mapping_kwargs = {'num_layers': 2}, # Arguments for MappingNetwork.
+ # rendering_kwargs = {},
+ # sr_kwargs = {},
+ synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim=z_dim
+ self.c_dim=c_dim
+ self.w_dim=w_dim
+ # self.img_resolution=img_resolution
+ # self.img_channels=img_channels
+ # self.renderer = ImportanceRenderer()
+ # self.ray_sampler = RaySampler()
+ self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+ # self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
+ self.decoder = OSGDecoder(32, {'decoder_output_dim': 0})
+ self.inference = False
+ # self.neural_rendering_resolution = 64
+ # self.rendering_kwargs = rendering_kwargs
+
+ self._last_planes = None
+ self.plane_axes = generate_planes()
+
+ def mapping(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ # if self.rendering_kwargs['c_gen_conditioning_zero']:
+ # c = torch.zeros_like(c)
+ # return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ return self.backbone.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+
+ def synthesis(self, ws, c=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
+ # cam2world_matrix = c[:, :16].view(-1, 4, 4)
+ # intrinsics = c[:, 16:25].view(-1, 3, 3)
+
+ # if neural_rendering_resolution is None:
+ # neural_rendering_resolution = self.neural_rendering_resolution
+ # else:
+ # self.neural_rendering_resolution = neural_rendering_resolution
+
+ # Create a batch of rays for volume rendering
+ # ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
+
+ # Create triplanes by running StyleGAN backbone
+ # N, M, _ = ray_origins.shape
+ if use_cached_backbone and self._last_planes is not None:
+ planes = self._last_planes
+ else:
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ if cache_backbone:
+ self._last_planes = planes
+
+ # Reshape output into three 32-channel planes
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+ return planes
+
+ # NOTE: the EG3D volume-rendering / super-resolution code below is unreachable (synthesis()
+ # returns the tri-planes above) and references members that are not defined in this class;
+ # it is kept, commented out, for reference only.
+ # Perform volume rendering
+ # feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
+ # Reshape into 'raw' neural-rendered image
+ # H = W = self.neural_rendering_resolution
+ # feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
+ # depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+ # Run superresolution to get final image
+ # rgb_image = feature_image[:, :3]
+ # sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
+ # return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
+
+ def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
+
+ def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
+ planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
+
+ def inference_mode(self):
+ self.inference = True
+
+ def forward(self, z, data=None, device='cuda', downsample=False, c=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
+ # Render a batch of generated images.
+ assert z.shape[-1] == self.z_dim
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ planes = self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
+ bs = planes.shape[0]
+ logits = []
+ for b in range(bs):
+ queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
+ if downsample and data['vox'][b].shape[0] > 50000:
+ # random sample
+ idx = torch.randperm(data['vox'][b].shape[0])[:50000]
+ queries = queries[:, idx]
+ out = sample_from_planes(self.plane_axes.to(device), planes[b:b+1], queries)
+ out = self.decoder(out)
+ if downsample and data['vox'][b].shape[0] > 50000:
+ logits.append((out.squeeze(0), idx))
+ else:
+ logits.append(out.squeeze(0))
+ if self.inference:
+ assert len(logits) == 1, "Inference mode only supports batch size 1"
+ return extract_keypoints(logits[0], data['vox'][0])
+ return logits
+
+
+class OSGDecoder(torch.nn.Module):
+ def __init__(self, n_features, options):
+ super().__init__()
+ self.hidden_dim = 64
+
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(n_features, self.hidden_dim),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'])
+ )
+
+ def forward(self, sampled_features, ray_directions=None):
+ # Aggregate features
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ return x
+ # NOTE: unreachable EG3D leftovers below the `return x` above; kept, commented out, for reference.
+ # rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ # sigma = x[..., 0:1]
+ # return {'rgb': rgb, 'sigma': sigma}
\ No newline at end of file
diff --git a/Anymate/models/skin.py b/Anymate/models/skin.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec81a907d6b4e15d321cc45a1b93e86da6bd2e8f
--- /dev/null
+++ b/Anymate/models/skin.py
@@ -0,0 +1,309 @@
+import torch
+import torch.nn as nn
+from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, Transformer
+from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics, FourierEmbedder
+from einops import repeat, rearrange
+
+class AttendjointsDecoder_combine(nn.Module):
+ def __init__(self,
+ width = 768,
+ layers = 2,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ separate = False,
+ use_mask = True,
+ use_bone = True,
+ inference= False):
+
+ super().__init__()
+ self.inference = inference
+ self.use_checkpoint = use_checkpoint
+ self.separate = separate
+ self.use_mask = use_mask
+ # self.num_latents = num_latents
+
+ # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
+
+ self.normal_embedder = components_from_spherical_harmonics
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
+ self.use_bone = use_bone
+
+ if not self.separate:
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+ self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
+ else:
+ self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
+
+
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
+
+ self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ ) for _ in range(layers)])
+
+ self.cross_attn_joint = nn.ModuleList([ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ ) for _ in range(layers)])
+
+ # self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
+
+
+ self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+ # self.last_cross_attn = ResidualCrossAttentionBlock(
+ # device=device,
+ # dtype=dtype,
+ # width=width,
+ # heads=heads,
+ # init_scale=init_scale,
+ # qkv_bias=qkv_bias,
+ # flash=flash,
+ # )
+ # self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+ # self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
+
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
+ joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
+ max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
+ mask = data['bones_mask'].to(device) if self.use_bone else data['joints_mask'].to(device)
+
+ pc = data['vertices'][..., 0:3].to(device) if self.inference else data['points_cloud'][..., 0:3].to(device)
+ feats = data['vertices'][..., 3:].to(device) if self.inference else data['points_cloud'][..., 3:].to(device)
+
+ if downsample and not self.inference:
+ # random sample
+ idx = torch.randperm(pc.shape[1])[:downsample].to(device)
+ pc = pc[:, idx]
+ feats = feats[:, idx]
+
+ # Embed the input data
+ co_embeds = self.fourier_embedder(pc)
+ if not self.separate:
+ co_embeds = self.co_proj(co_embeds)
+
+ if self.use_bone:
+ # joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
+ joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
+ else:
+ joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
+
+ if not self.separate:
+ joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
+
+ normal_embeds = self.normal_proj(self.normal_embedder(feats)) if not self.separate else self.normal_embedder(feats)
+
+ if not self.separate:
+ pc_embeds = co_embeds + normal_embeds
+ else:
+ joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
+ pc_embeds = self.pc_proj(torch.cat([co_embeds.to(dtype), normal_embeds.to(dtype)], dim=-1))
+
+ pc_num = pc_embeds.shape[-2]
+ joints_num = joints_embeds.shape[-2]
+ x = torch.cat([pc_embeds, joints_embeds], dim=-2)
+ for i, layer in enumerate(self.cross_attn):
+
+ x = layer(x, latents)
+ if self.use_mask:
+ x = self.cross_attn_joint[i](x, x[:, pc_num:], mask=mask.to(device))
+ else:
+ x = self.cross_attn_joint[i](x, x[:, pc_num:])
+ pc_embeds, joints_embeds = x.split([pc_num, joints_num], dim=1)
+
+ logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(pc_embeds)), self.q_proj(self.ln_2(joints_embeds))) # (b, n, m)
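+ # logits[b, n, m]: per-point score of point n against bone/joint m; padded bones are masked
+ # to -1e8 below so they contribute negligible weight after any downstream softmax.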
+
+ if self.use_mask:
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
+
+ if downsample and not self.inference:
+ return logits, idx
+
+ return logits
+
+class AttendjointsDecoder_multi(nn.Module):
+ def __init__(self,
+ # num_latents = 64,
+ # num_kv_latents = 257,
+ # out_channels = 3,
+ width = 768,
+ layers = 4,
+ device = 'cuda',
+ dtype = torch.float32,
+ heads = 12,
+ init_scale: float = 0.25,
+ flash = False,
+ use_checkpoint = False,
+ qkv_bias = False,
+ num_freqs: int = 8,
+ concat_num: int = 512,
+ include_pi: bool = True,
+ separate = False,
+ use_mask = True,
+ inference_with_repeat=False,
+ use_bone = True,
+ inference = False):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.use_mask = use_mask
+ self.inference_with_repeat = inference_with_repeat
+ self.inference = inference
+
+ self.self_attn = Transformer(
+ device=device,
+ dtype=dtype,
+ n_ctx=-1,
+ width=width,
+ layers=layers,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=False,
+
+ )
+ self.concat_number = concat_num
+ self.separate = separate
+ self.normal_embedder = components_from_spherical_harmonics
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
+ self.use_bone = use_bone
+ if not self.separate:
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
+ self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
+ else:
+ self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
+
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
+
+ # self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
+ self.output_proj_points = nn.Linear(width, width, device=device, dtype=dtype)
+ self.layer_norm = nn.LayerNorm(width)
+
+ # def inference(self, latents, data=None,device='cuda', dtype='float32', use_mask=False):
+ def inference_mode(self):
+ self.inference = True
+
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
+ joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
+ max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
+
+ pc = data['points_cloud'][..., 0:3].to(device)
+ feats = data['points_cloud'][..., 3:].to(device)
+
+ if downsample:
+ # random sample
+ idx = torch.randperm(pc.shape[1])[:downsample].to(device)
+ pc = pc[:, idx]
+ feats = feats[:, idx]
+
+ bs = pc.shape[1]//self.concat_number
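+ # NOTE: despite the name, `bs` is the number of point chunks per sample
+ # (points per cloud // concat_number), not the batch size.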
+
+ # Embed the input data
+ if self.use_bone:
+ # joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
+ joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
+ else:
+ joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
+
+ if self.separate:
+ joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
+ points_embeds = self.fourier_embedder(pc)
+ normal_embeds = self.normal_embedder(feats)
+ points = self.pc_proj(torch.cat([points_embeds, normal_embeds], dim=-1))
+ else:
+ joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
+ co_embeds = self.fourier_embedder(pc)
+ co_embeds = self.co_proj(co_embeds)
+ # Embed the normals
+ normal_embeds = self.normal_embedder(feats)
+ normal_embeds = self.normal_proj(normal_embeds) # (b, n, c)
+ points = (co_embeds + normal_embeds)
+
+ repeated_latents = repeat(latents, "b m c -> b n m c", n=bs)
+ repeated_joints = repeat(joints_embeds, "b m c -> b n m c", n=bs)
+ points = points.reshape( latents.shape[0], bs, self.concat_number, -1)
+
+ # Concatenate embeddings
+ x = torch.cat([repeated_joints, points, repeated_latents], dim=-2) # (b, bs, concat_number+latent_num+joints_num, c)
+
+ # Pass through self-attention
+ if self.use_mask:
+ mask = data['bones_mask'].to(device)
+ append_size = x.shape[2]-mask.shape[1] # the zero needs to append after mask
+ batch_size = mask.shape[0]
+ mask_extend = torch.ones((batch_size,append_size)).to(device)
+ mask = torch.cat([mask,mask_extend],dim=-1).repeat(bs,1).to(device)
+ x = rearrange(x, "b n m c -> (b n) m c")
+ x = self.self_attn(x,mask)
+ else:
+ x = rearrange(x, "b n m c -> (b n) m c")
+ x = self.self_attn(x)
+ joints, points, _ = x.split([joints_embeds.shape[1],self.concat_number, latents.shape[1]], dim=1)
+ joints = self.output_proj_joints(self.layer_norm(joints))
+ points = self.output_proj_points(self.layer_norm(points))
+
+ logits = torch.einsum('bik,bjk->bij', points, joints)
+ logits = rearrange(logits, '(b n) m c -> b (n m) c', b=pc.shape[0],n=bs) # (b, n, c)
+
+ if self.use_mask:
+ mask = data['bones_mask'].to(device)
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
+
+ if self.inference:
+ vertices = data['vertice']
+ points_cloud = data['points_cloud'][0,..., 0:3].to(device)
+ vertices_exp = vertices[0, ..., :3].to(device) # (num_vertices, 3)
+ logits = compute_nearest_points(vertices_exp, points_cloud, logits[0], device)
+
+ if downsample:
+ return logits, idx
+
+ return logits
+
+def compute_nearest_points(vertices, points, logits, device, batch_size=1024):
+ # vertices: [N, 3]
+ # points: [M, 3]
+ # logits: [M, K] (K is the number of skinning weights)
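+ # Propagates per-point predictions to mesh vertices: each vertex copies the logits of its
+ # nearest sampled point, processed in chunks of `batch_size` vertices to bound memory use.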
+
+ num_vertices = vertices.shape[0]
+ # Initialize the output tensor for skinning weights
+ skin_predict = torch.zeros((num_vertices, logits.shape[1]), device=device)
+
+ # Split vertices into batches
+ for i in range(0, num_vertices, batch_size):
+
+ batch_vertices = vertices[i:i+batch_size] # [batch_size, 3]
+ vertices_exp = batch_vertices.unsqueeze(1) # [batch_size, 1, 3]
+ points_exp = points.unsqueeze(0) # [1, num_points, 3]
+ distances = torch.sum((vertices_exp - points_exp) ** 2, dim=-1) # [batch_size, num_points]
+ nearest_idx = torch.argmin(distances, dim=-1) # [batch_size]
+ skin_predict_batch = logits[nearest_idx] # [batch_size, K]
+ skin_predict[i:i+batch_size] = skin_predict_batch
+
+ return skin_predict
\ No newline at end of file
diff --git a/Anymate/tmp/.gitkeep b/Anymate/tmp/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Anymate/utils/dataset_utils.py b/Anymate/utils/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19563ccf9b1ee3a3991d60811c073f43b952d3bc
--- /dev/null
+++ b/Anymate/utils/dataset_utils.py
@@ -0,0 +1,129 @@
+import numpy as np
+import torch
+import trimesh
+from ThirdParty.Rignet_utils import binvox_rw
+
+
+def sparse_to_index(sparse_matrix):
+ index = []
+ weight = []
+ for j in range(len(sparse_matrix)):
+ if sparse_matrix[j] > 0:
+ index.append(j)
+ weight.append(sparse_matrix[j])
+
+ return index, weight
+
+def index_to_sparse(index, weight, shape):
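+ # Scatter per-point (index, weight) pairs back into a dense (batch, points, joints) skinning
+ # matrix; padded indices land in an extra last column that is dropped before returning.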
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
+
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
+
+ row_indices = np.expand_dims(row_indices, axis=-1)
+ col_indices = np.expand_dims(col_indices, axis=-1)
+
+ sparse_matrix[row_indices, col_indices, index] = weight
+
+
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
+
+def index_to_sparse_con(index, shape):
+
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1],dtype=np.int8)
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
+
+ row_indices = np.expand_dims(row_indices, axis=-1)
+ col_indices = np.expand_dims(col_indices, axis=-1)
+
+ sparse_matrix[row_indices, col_indices, index] = 1
+
+
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
+
+def create_mask(n, max_len=64):
+ mask = torch.zeros(max_len, dtype=torch.bool)
+ mask[:n] = 1
+ return mask
+
+def reduce(vox):
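+ # Downsample a binvox grid by 2x: a coarse voxel is occupied if any of its eight children is,
+ # then the result is dilated by one voxel along each axis to avoid holes.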
+ new_data = np.zeros((vox.dims[0] // 2, vox.dims[1] // 2, vox.dims[2] // 2)).astype(bool)
+ new_data = np.logical_or(new_data, vox.data[::2, ::2, ::2])
+ new_data = np.logical_or(new_data, vox.data[1::2, ::2, ::2])
+ new_data = np.logical_or(new_data, vox.data[::2, 1::2, ::2])
+ new_data = np.logical_or(new_data, vox.data[::2, ::2, 1::2])
+ new_data = np.logical_or(new_data, vox.data[1::2, 1::2, ::2])
+ new_data = np.logical_or(new_data, vox.data[1::2, ::2, 1::2])
+ new_data = np.logical_or(new_data, vox.data[::2, 1::2, 1::2])
+ new_data = np.logical_or(new_data, vox.data[1::2, 1::2, 1::2])
+ # dilate the new voxel
+ new_data[:-1, :, :] = np.logical_or(new_data[:-1, :, :], new_data[1:, :, :])
+ new_data[:, :-1, :] = np.logical_or(new_data[:, :-1, :], new_data[:, 1:, :])
+ new_data[:, :, :-1] = np.logical_or(new_data[:, :, :-1], new_data[:, :, 1:])
+ return binvox_rw.Voxels(new_data, new_data.shape, vox.translate, vox.scale, vox.axis_order)
+
+def align(vox, y_max):
+ new_data = np.zeros(vox.dims).astype(bool)
+ ind = np.argwhere(vox.data)
+ ind = ind + (np.array(vox.translate) - np.array([-0.5, -0.5 * (1 - y_max), -0.5])) * vox.dims[0]
+ # round to the nearest integer
+ # ind = np.round(ind).astype(int)
+ ind = np.ceil(ind).astype(int)
+ # clip to the valid range
+ ind = np.clip(ind, 0, vox.dims[0] - 1)
+ # new_data[ind[:, 0], ind[:, 1], ind[:, 2]] = True
+ return ind
+
+def get_skin_direction(joint_idx, data, parent_index, joints_matrix):
+ # Get points influenced by this joint (weight > 0)
+ weights = index_to_sparse(data['skins_index'].unsqueeze(0), data['skins_weight'].unsqueeze(0), [1, 8192, data['bones_num']])[0][:,joint_idx]
+ mask = weights > 0
+
+ if not torch.any(mask):
+ # If no points are influenced, return the opposite direction of its parent
+ parent_idx = parent_index[joint_idx].item()
+ if parent_idx == joint_idx:
+ return torch.tensor([0, 0, 0.001])
+ parent_pos = joints_matrix[parent_idx, :3]
+ joint_pos = joints_matrix[joint_idx, :3]
+ direction = joint_pos - parent_pos
+ norm = torch.norm(direction)
+ if norm < 1e-8: # Add check for zero norm
+ return torch.tensor([0, 0, 0.001])
+ normalized_direction = direction / norm
+ return normalized_direction * 0.01
+
+ # Get joint position
+ joint_pos = joints_matrix[joint_idx, :3]
+
+ # Get weighted average direction from joint to influenced points
+ points = data['pc'][mask][:,:3]
+ point_weights = weights[mask]
+
+ # Calculate directions from joint to each point
+ directions = points - joint_pos
+
+ # Calculate weighted average direction
+ avg_direction = torch.sum(directions * point_weights.unsqueeze(1), dim=0) / torch.sum(point_weights)
+ if torch.norm(avg_direction) < 1e-5:
+ return torch.tensor([0, 0, 0.001])
+ return avg_direction * 1.25
+
+def obj2mesh(obj_path):
+ # open the obj as txt
+ vertices = []
+ faces = []
+ with open(obj_path, 'r') as f:
+ obj = f.readlines()
+ for line in obj:
+ if line.startswith('v '):
+ vertices.append(list(map(float, line.split()[1:])))
+ elif line.startswith('f '):
+ faces.append(list(map(int, [i.split('/')[0] for i in line.split()[1:]])))
+ vertices = np.array(vertices)
+ faces = np.array(faces) - 1
+ # print(vertices.shape, faces.shape)
+
+ # create trimesh mesh with given vertices and faces
+ mesh = trimesh.Trimesh(vertices, faces, process=False)
+ # print(mesh.vertices.shape, mesh.faces.shape)
+ return mesh
\ No newline at end of file
diff --git a/Anymate/utils/diffusion_encoder.py b/Anymate/utils/diffusion_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0060d314cc64398ebe0aa28b9b671d4ef1e396
--- /dev/null
+++ b/Anymate/utils/diffusion_encoder.py
@@ -0,0 +1,258 @@
+import torch
+import torch.nn as nn
+from typing import Optional
+from einops import repeat
+import math
+from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock,Transformer, checkpoint
+from torch.nn import Sequential, Dropout, Linear, ReLU, Parameter, BatchNorm1d
+from typing import List, Optional, Tuple, Union
+
+class ShapeAsLatentModule(nn.Module):
+ latent_shape: Tuple[int, int]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def query_geometry(self, *args, **kwargs):
+ raise NotImplementedError
+
+class FourierEmbedder(nn.Module):
+
+ def __init__(self,
+ num_freqs: int = 6,
+ logspace: bool = True,
+ input_dim: int = 3,
+ include_input: bool = True,
+ include_pi: bool = True) -> None:
+
+ """The initialization"""
+
+ super().__init__()
+
+ if logspace:
+ frequencies = 2.0 ** torch.arange(
+ num_freqs,
+ dtype=torch.float32
+ )
+ else:
+ frequencies = torch.linspace(
+ 1.0,
+ 2.0 ** (num_freqs - 1),
+ num_freqs,
+ dtype=torch.float32
+ )
+
+ if include_pi:
+ frequencies *= torch.pi
+
+ self.register_buffer("frequencies", frequencies, persistent=False)
+ self.include_input = include_input
+ self.num_freqs = num_freqs
+
+ self.out_dim = self.get_dims(input_dim)
+
+ def get_dims(self, input_dim):
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+ return out_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ if self.num_freqs > 0:
+ self.frequencies = self.frequencies.to(x.device)
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
+
+ if self.include_input:
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+ else:
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
+ else:
+ return x
+
+def MLP(channels, batch_norm=True):
+ if batch_norm:
+ return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU(), BatchNorm1d(channels[i], momentum=0.1))
+ for i in range(1, len(channels))])
+ else:
+ return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU()) for i in range(1, len(channels))])
+
+class CrossAttentionEncoder(nn.Module):
+
+ def __init__(self, *,
+ device: Optional[torch.device],
+ dtype: Optional[torch.dtype],
+ num_latents: int,
+ fourier_embedder: FourierEmbedder,
+ point_feats: int,
+ width: int,
+ heads: int,
+ layers: int,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+ self.num_latents = num_latents
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
+
+ self.fourier_embedder = fourier_embedder
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
+ self.cross_attn = ResidualCrossAttentionBlock(
+ device=device,
+ dtype=dtype,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ )
+
+ self.self_attn = Transformer(
+ device=device,
+ dtype=dtype,
+ n_ctx=num_latents,
+ width=width,
+ layers=layers,
+ heads=heads,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_checkpoint=False
+ )
+
+ if use_ln_post:
+ self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
+ else:
+ self.ln_post = None
+
+ def _forward(self, pc, feats):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, C]
+
+ Returns:
+
+ """
+
+ bs = pc.shape[0]
+
+ data = self.fourier_embedder(pc)
+ if feats is not None:
+ data = torch.cat([data, feats], dim=-1)
+ data = self.input_proj(data)
+
+ query = repeat(self.query, "m c -> b m c", b=bs)
+ latents = self.cross_attn(query, data)
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ return latents, pc
+
+ def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
+ """
+
+ Args:
+ pc (torch.FloatTensor): [B, N, 3]
+ feats (torch.FloatTensor or None): [B, N, C]
+
+ Returns:
+ dict
+ """
+
+ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
+
+
+
+class TransformerEncoder(ShapeAsLatentModule):
+ def __init__(self, *,
+ device: Optional[torch.device]='cuda',
+ dtype: Optional[torch.dtype],
+ num_latents: int = 16,
+ point_feats: int = 3,
+ embed_dim: int = 64,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ width: int = 768,
+ heads: int = 12,
+ num_encoder_layers: int = 8,
+ init_scale: float = 0.25,
+ qkv_bias: bool = True,
+ flash: bool = False,
+ use_ln_post: bool = False,
+ use_checkpoint: bool = False,
+ out_channels: int = 4):
+
+ super().__init__()
+
+ self.use_checkpoint = use_checkpoint
+
+ self.num_latents = num_latents
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+
+ init_scale = init_scale * math.sqrt(1.0 / width)
+ self.encoder = CrossAttentionEncoder(
+ device=device,
+ dtype=dtype,
+ fourier_embedder=self.fourier_embedder,
+ num_latents=num_latents,
+ point_feats=point_feats,
+ width=width,
+ heads=heads,
+ layers=num_encoder_layers,
+ init_scale=init_scale,
+ qkv_bias=qkv_bias,
+ flash=flash,
+ use_ln_post=use_ln_post,
+ use_checkpoint=use_checkpoint
+ )
+ self.width = width
+ self.out_channels = out_channels
+ self.device = device
+
+ self.embed_dim = embed_dim
+
+ def encode(self,data):
+ input_points = data['points_cloud'].to(self.device)
+ bs = input_points.shape[0]
+ pc, feats = input_points[...,:3], input_points[..., 3:]
+ latents, _ = self.encoder(pc, feats)
+ # print_time('after encoder')
+ latents = latents.reshape(bs, -1, self.width)
+ return latents
+
+ def encode_pc(self, points_cloud):
+ bs = points_cloud.shape[0]
+ input_points = points_cloud.to(self.device)
+ pc, feats = input_points[...,:3], input_points[..., 3:]
+ latents, _ = self.encoder(pc, feats)
+
+ latents = latents.reshape(bs, -1, self.width)
+ return latents
+
+ def forward(self, data):
+
+ # input_points = torch.from_numpy(np.array(data.points_cloud)).cuda()
+ input_points = data['points_cloud'].to(self.device)
+ pc, feats = input_points[...,:3], input_points[..., 3:]
+ latents, _ = self.encoder(pc, feats)
+
+ latents = latents.reshape(-1, self.width)
+ latents = latents.reshape(-1, self.num_latents, self.out_channels)
+ latents[..., :3] = torch.tanh(latents[..., :3]) # first 3 channels squashed to [-1, 1]
+ latents[..., 3:] = torch.sigmoid(latents[..., 3:]) # remaining channels squashed to [0, 1]
+
+
+ return latents
\ No newline at end of file
diff --git a/Anymate/utils/diffusion_utils.py b/Anymate/utils/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c171038ede38173542450c79feb5b9c375f81f4e
--- /dev/null
+++ b/Anymate/utils/diffusion_utils.py
@@ -0,0 +1,314 @@
+
+import numpy as np
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+from torchvision.utils import make_grid
+import torch
+from typing import List, Optional, Tuple, Union
+import torch.nn as nn
+import math
+from timm.models.vision_transformer import Mlp, DropPath
+
+def my_collate_diff(batch,return_joints_num=128,random=False):
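+ # Custom collate for the diffusion dataloader: keeps variable-length keys as lists, stacks the
+ # fixed-size tensors, and builds a fixed-size 'joints_repeat' target of `return_joints_num`
+ # joints per sample (tiled, or randomly resampled when `random=True`), plus a (96, 3) 'joints'
+ # matrix padded with -3.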
+ data = {}
+ for key in batch[0]:
+ if key in ('vox', 'name', 'joints_num', 'skins_index', 'skins_weight', 'parent_index', 'conns', 'joints', 'bones', 'mesh_skins_index', 'mesh_skins_weight', 'mesh_pc', 'mesh_face'):
+ data[key] = [sample[key] for sample in batch]
+ elif key=='pc':
+ data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
+ elif key=='skins':
+ continue
+ elif key=='bones_num':
+ data[key] = torch.tensor([sample['bones_num'] for sample in batch])
+ else:
+ data[key] = torch.stack([sample[key] for sample in batch])
+
+ if 'joints' in batch[0]:
+ padded_joints_matrix = torch.ones(len(data['name']), return_joints_num, 3) * (-3)
+ joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
+ for i in range(len(data['name'])):
+ joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
+ if not random:
+ for i in range(len(data['name'])):
+ padded_joints_matrix[i] = data['joints'][i].repeat(return_joints_num//data['joints_num'][i]+1,1)[:return_joints_num,:]
+ else:
+ for i in range(len(data['name'])):
+ padded_joints_matrix[i] = data['joints'][i][torch.randint(0, data['joints_num'][i], (return_joints_num,))]
+ data['joints_repeat'] = padded_joints_matrix
+ data['joints'] = joints_matrix
+
+ return data
+
+def randn_tensor(
+ shape: Union[Tuple, List],
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+ device: Optional["torch.device"] = None,
+ dtype: Optional["torch.dtype"] = None,
+ layout: Optional["torch.layout"] = None,
+):
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+ passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
+ is always created on the CPU.
+ """
+ # the tensor is created on `rand_device`, which defaults to the requested `device`
+ rand_device = device
+ batch_size = shape[0]
+
+ layout = layout or torch.strided
+ device = device or torch.device("cpu")
+
+ if generator is not None:
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+ if gen_device_type != device.type and gen_device_type == "cpu":
+ rand_device = "cpu"
+
+ elif gen_device_type != device.type and gen_device_type == "cuda":
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+ # make sure generator list of length 1 is treated like a non-list
+ if isinstance(generator, list) and len(generator) == 1:
+ generator = generator[0]
+
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+ for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+ return latents
+
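+# Illustrative usage of randn_tensor (hypothetical shapes; not part of the original code):
+#   gen = torch.Generator(device='cpu').manual_seed(0)
+#   noise = randn_tensor((4, 128, 6), generator=gen, device=torch.device('cpu'))
+#   noise.shape  # -> torch.Size([4, 128, 6])
+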
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
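+# Illustrative usage of timestep_embedding (hypothetical values; not part of the original code):
+#   t = torch.tensor([0, 10, 999])
+#   emb = timestep_embedding(t, dim=256)
+#   emb.shape  # -> torch.Size([3, 256])
+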
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ kv_dim=None,
+ num_heads=16,
+ qkv_bias=False,
+ attn_drop=0.,
+ proj_drop=0.,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ kv_dim = dim if not kv_dim else kv_dim
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
+ self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
+ self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
+ self.attn_drop_rate = attn_drop
+ self.attn_drop = nn.Dropout(self.attn_drop_rate)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x_q, x_kv):
+ B, N_q, C = x_q.shape
+ B, N_kv, _ = x_kv.shape
+ # [B, N_q, C] -> [B, N_q, H, C/H] -> [B, H, N_q, C/H]
+ q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
+ k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
+ v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ # [B, H, N_q, C/H] @ [B, H, C/H, N_kv] -> [B, H, N_q, N_kv]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ # [B, H, N_q, N_kv] @ [B, H, N_kv, C/H] -> [B, H, N_q, C/H]
+ x = attn @ v
+
+ # [B, H, N_q, C/H] -> [B, N_q, C]
+ x = x.transpose(1, 2).reshape(B, N_q, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
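+# Illustrative usage of CrossAttention (hypothetical dimensions; not part of the original code):
+#   attn = CrossAttention(dim=256, kv_dim=64, num_heads=8)
+#   q_tokens = torch.rand(2, 96, 256)    # queries
+#   kv_tokens = torch.rand(2, 512, 64)   # keys/values from another token set
+#   out = attn(q_tokens, kv_tokens)      # -> shape [2, 96, 256]
+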
+class Compute_Block(nn.Module):
+
+ def __init__(self, z_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm_z1 = norm_layer(z_dim)
+ self.attn = CrossAttention(
+ z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm_z2 = norm_layer(z_dim)
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
+ self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, z):
+ zn = self.norm_z1(z)
+ z = z + self.drop_path(self.attn(zn, zn))
+ z = z + self.drop_path(self.mlp(self.norm_z2(z)))
+ return z
+
+class Read_Block(nn.Module):
+
+ def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm_x = norm_layer(x_dim)
+ self.norm_z1 = norm_layer(z_dim)
+ self.attn = CrossAttention(
+ z_dim, x_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm_z2 = norm_layer(z_dim)
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
+ self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, z, x):
+ z = z + self.drop_path(self.attn(self.norm_z1(z), self.norm_x(x)))
+ z = z + self.drop_path(self.mlp(self.norm_z2(z)))
+ return z
+
+class Write_Block(nn.Module):
+
+ def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm_z = norm_layer(z_dim)
+ self.norm_x1 = norm_layer(x_dim)
+ self.attn = CrossAttention(
+ x_dim, z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm_x2 = norm_layer(x_dim)
+ mlp_hidden_dim = int(x_dim * mlp_ratio)
+ self.mlp = Mlp(in_features=x_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, z, x):
+ x = x + self.drop_path(self.attn(self.norm_x1(x), self.norm_z(z)))
+ x = x + self.drop_path(self.mlp(self.norm_x2(x)))
+ return x
+
+class RCW_Block(nn.Module):
+
+ def __init__(self, z_dim, x_dim, num_compute_layers=4, num_heads=16,
+ mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.read = Read_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
+ self.write = Write_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
+ self.compute = nn.ModuleList([
+ Compute_Block(z_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(num_compute_layers)
+ ])
+
+ def forward(self, z, x):
+ z = self.read(z, x)
+ for layer in self.compute:
+ z = layer(z)
+ x = self.write(z, x)
+ return z, x
+
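+# Illustrative usage of RCW_Block (hypothetical dimensions; not part of the original code):
+#   block = RCW_Block(z_dim=256, x_dim=64, num_compute_layers=2, num_heads=8)
+#   z = torch.rand(2, 96, 256)   # latent tokens
+#   x = torch.rand(2, 512, 64)   # conditioning tokens
+#   z, x = block(z, x)           # read -> compute -> write; shapes are unchanged
+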
+def pairwise_distances(x, y):
+    # Input:  x is an N x d matrix
+    #         y is an M x d matrix
+    # Output: dist is an N x M matrix where dist[i, j] is the squared distance between x[i, :] and y[j, :],
+    #         i.e. dist[i, j] = ||x[i, :] - y[j, :]||^2
+ x_norm = (x ** 2).sum(1).view(-1, 1)
+ y_t = torch.transpose(y, 0, 1)
+ y_norm = (y ** 2).sum(1).view(1, -1)
+ dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
+ return torch.clamp(dist, 0.0, np.inf)
+
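+# Illustrative usage of pairwise_distances (hypothetical shapes; not part of the original code):
+#   x = torch.rand(5, 3)
+#   y = torch.rand(7, 3)
+#   d2 = pairwise_distances(x, y)  # -> [5, 7] matrix of squared distances
+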
+def meanshift_cluster(pts_in, bandwidth, weights=None, max_iter=20):
+    """
+    Mean-shift clustering.
+    :param pts_in: input points
+    :param bandwidth: kernel bandwidth
+    :param weights: per-point weights indicating each point's importance in the clustering
+    :param max_iter: maximum number of shift iterations
+    :return: points after clustering
+    """
+ diff = 1e10
+ num_iter = 1
+ while diff > 1e-3 and num_iter < max_iter:
+ Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
+ K = np.maximum(bandwidth**2 - Y, np.zeros(Y.shape))
+ if weights is not None:
+ K = K * weights
+ row_sums = K.sum(axis=0, keepdims=True)
+ P = K / (row_sums + 1e-10)
+ P = P.transpose()
+ pts_in_prim = 0.3 * (np.matmul(P, pts_in) - pts_in) + pts_in
+ diff = np.sqrt(np.sum((pts_in_prim - pts_in)**2))
+ pts_in = pts_in_prim
+ num_iter += 1
+ return pts_in
+
+def nms_meanshift(pts_in, density, bandwidth):
+ """
+    NMS to extract modes after mean-shift. Implementation follows scikit-learn.
+ :param pts_in: input points
+ :param density: density at each point
+ :param bandwidth: bandwidth used in meanshift. Used here as neighbor region for NMS
+ :return: extracted clusters.
+ """
+ Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
+ sorted_ids = np.argsort(density)[::-1]
+ unique = np.ones(len(sorted_ids), dtype=bool)
+ dist = np.sqrt(Y)
+ for i in sorted_ids:
+ if unique[i]:
+ neighbor_idxs = np.argwhere(dist[:, i] <= bandwidth)
+ unique[neighbor_idxs.squeeze()] = 0
+ unique[i] = 1 # leave the current point as unique
+ pts_in = pts_in[unique]
+ return pts_in
+
+def get_predictions(y_pred_np, attn_pred_np=None, bandwidth=0.05, threshold=0.001):
+    """
+    Get the final joint predictions by mean-shift clustering followed by NMS.
+    :param y_pred_np: predicted joint positions (N x 3)
+    :param attn_pred_np: optional per-point weights used during clustering
+    :param bandwidth: kernel bandwidth for mean-shift and NMS
+    :param threshold: density threshold for discarding low-density points
+    :return: clustered joint positions
+    """
+ # if attn_pred_np is None:
+ # attn_pred_np = np.ones(y_pred_np.shape[0])
+ y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
+
+
+ Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
+ density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
+ density = np.sum(density, axis=0)
+ density_sum = np.sum(density)
+ y_pred_np = y_pred_np[density / density_sum > threshold]
+
+ density = density[density / density_sum > threshold]
+ pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
+ return pred_joints
+
+
+if __name__ == '__main__':
+ points_cloud = np.ones((100, 3))
+ predict_out = get_predictions(points_cloud, bandwidth=0.05, threshold=0.001)
+ print(predict_out.shape)
+
\ No newline at end of file
diff --git a/Anymate/utils/eval_utils.py b/Anymate/utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58622c660cdc2482d924197d65019c94816cd51
--- /dev/null
+++ b/Anymate/utils/eval_utils.py
@@ -0,0 +1,225 @@
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+import numpy as np
+import point_cloud_utils as pcu
+from Anymate.utils.loss_utils import chamfer_distance_with_average, cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp
+from ThirdParty.Rignet_utils.utils import get_skel
+from ThirdParty.Rignet_utils.Rignet_loss import edit_dist, chamfer_dist, joint2bone_chamfer_dist, bone2bone_chamfer_dist
+from scipy.optimize import linear_sum_assignment
+
+def evaluate_joint(joints, joints_gt, threshold=1e-1):
+ """
+ joints: list of predicted joints: tensor of shape (n,joints_num,3)
+ joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
+ """
+ chamfer_loss_all = 0
+ emd_loss_all = 0
+ precision = 0
+ recall = 0
+ count = 0
+
+ for i in tqdm(range(len(joints))):
+ joint_predict = joints[i].cpu()
+ joint_gt = joints_gt[i].cpu()
+ distance_matrix = torch.cdist(joint_gt, joint_predict) # (n_gt, n_predict)
+ n_gt,n_predict = distance_matrix.shape
+ min_distance_pred = torch.min(distance_matrix, dim=0)
+ min_distance_gt = torch.min(distance_matrix, dim=1)
+ precision += torch.sum(min_distance_pred.values < threshold).item()/n_predict
+ recall += torch.sum(min_distance_gt.values < threshold).item()/n_gt
+
+ chamfer_loss_all += chamfer_distance_with_average(joint_predict.unsqueeze(0), joint_gt.unsqueeze(0))
+ joint_predict = joint_predict.numpy().astype(np.float64)
+ joint_gt = joint_gt.numpy().astype(np.float64)
+ emd,_ = pcu.earth_movers_distance(joint_predict, joint_gt)
+ emd_loss_all += emd
+
+ count += 1
+
+ print('------------------------------------')
+ print('Evaluation results for joint:')
+ print('chamfer_loss:', chamfer_loss_all/count)
+ print('emd_loss:', emd_loss_all/count)
+ print('precision:', precision/count)
+ print('recall:', recall/count)
+ print('count:', count)
+ print('------------------------------------')
+ return chamfer_loss_all/count, emd_loss_all/count, precision/count, recall/count
+
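+# Illustrative usage of evaluate_joint (hypothetical predictions; not part of the original code):
+#   pred = [torch.rand(24, 3) for _ in range(4)]
+#   gt = [torch.rand(20, 3) for _ in range(4)]
+#   chamfer, emd, precision, recall = evaluate_joint(pred, gt, threshold=1e-1)
+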
+def evaluate_connectivity(conns, conns_gt, joints_gt, vox_list):
+
+ """
+ conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
+ conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
+ """
+
+ precision_all = 0
+ recall_all = 0
+ cross_entropy_all = 0
+ bone2bone_dist_con = 0
+ count = 0
+ for i in tqdm(range(len(conns))):
+
+ conn_predict = conns[i].cpu().numpy()
+ conn_gt = conns_gt[i].cpu().numpy()
+ joints = joints_gt[i].cpu().numpy()
+ vox = vox_list[i]
+
+ cross_entropy_all += cross_entropy_with_probs_batch(torch.from_numpy(conn_predict).unsqueeze(0), torch.from_numpy(conn_gt).unsqueeze(0), reduction='mean')
+ # consider to add tree edit distance (need joint and vox information)
+ pred_skel, parent_matrix = get_skel(joints, conn_predict, vox=vox)
+ gt_skel, parent_matrix = get_skel(joints, conn_gt, vox=vox)
+ bone2bone_dist_con += bone2bone_chamfer_dist(pred_skel, gt_skel)
+
+ conn_predict = np.argmax(conn_predict, axis=1)
+ conn_gt = np.argmax(conn_gt, axis=1)
+ connection_matrix_pre = torch.zeros((len(conn_predict),len(conn_predict)))
+ connection_matrix_gt = torch.zeros((len(conn_predict),len(conn_predict)))
+
+        # use j for the per-joint loops to avoid shadowing the outer sample index i
+        for j in range(len(conn_predict)):
+            connection_matrix_pre[j][conn_predict[j]] = 1
+            connection_matrix_pre[conn_predict[j]][j] = 1
+            connection_matrix_gt[j][conn_gt[j]] = 1
+            connection_matrix_gt[conn_gt[j]][j] = 1
+
+        TP = 0
+        FP = 0
+        FN = 0
+
+        for j in range(len(conn_predict)):
+            if connection_matrix_gt[j][conn_predict[j]] == 1:
+                TP += 1
+            if connection_matrix_gt[j][conn_predict[j]] == 0:
+                FP += 1
+            if connection_matrix_pre[j][conn_gt[j]] == 0:
+                FN += 1
+
+ precision = TP/(TP+FP)
+ recall = TP/(TP+FN)
+
+ precision_all += precision
+ recall_all += recall
+ count+=1
+ print('------------------------------------')
+ print('Evaluation results for connectivity:')
+ print('precision:',precision_all/count)
+ print('recall:',recall_all/count)
+ print('cross_entropy:',cross_entropy_all/count)
+ print('bone2bone_dist_con:',bone2bone_dist_con/count)
+ print('count:',count)
+ print('------------------------------------')
+ return precision_all/count, recall_all/count
+
+def evaluate_skinning(skins, skins_gt, threshold=5e-2):
+ """
+ skins: list of predicted skinning weights: tensor of shape (n,vertices_num, bones_num)
+ skins_gt: list of ground truth skinning weights: tensor of shape (n,vertices_num, bones_num)
+ """
+ cs_loss = 0
+ ce_loss = 0
+ cs_loss_clamp = 0
+ count = 0
+ L1_loss = 0
+ precision = 0
+ recall = 0
+ mean_l1_dist = 0
+
+ for i in tqdm(range(len(skins))):
+ skin_predict = skins[i].cpu().unsqueeze(0)
+ skin_gt = skins_gt[i].cpu().unsqueeze(0)
+
+ precision_one = 0
+ recall_one = 0
+
+ ce_loss += cross_entropy_with_probs_batch(skin_predict, skin_gt)
+ cs_loss += cos_loss(skin_predict, skin_gt)
+ cs_loss_clamp += cos_loss_clamp(skin_predict, skin_gt)
+ L1_loss += F.l1_loss(skin_predict, skin_gt)
+ skin_predict = skin_predict[0].cpu().detach().numpy()
+ skin_gt = skin_gt[0].cpu().detach().numpy()
+ mean_l1_dist += np.sum(np.abs(skin_predict - skin_gt )) / len(skin_predict)
+
+        # use j for the per-vertex loop to avoid shadowing the outer sample index i
+        for j in range(len(skin_predict)):
+            influential_bone_predict = skin_predict[j] >= threshold
+            influential_bone_gt = skin_gt[j] >= threshold
+            influential_bone_correct = influential_bone_predict * influential_bone_gt
+
+            if np.sum(influential_bone_predict) == 0 or np.sum(influential_bone_gt) == 0:
+                continue
+            precision_one += np.sum(influential_bone_correct) / np.sum(influential_bone_predict)
+            recall_one += np.sum(influential_bone_correct) / np.sum(influential_bone_gt)
+
+ precision += precision_one/len(skin_predict)
+ recall += recall_one/len(skin_predict)
+ count +=1
+
+ print('------------------------------------')
+ print('Evaluation results for skinning:')
+ print('cos loss: ', cs_loss/count)
+ print('ce loss: ', ce_loss/count)
+ print('cs_loss_clamp: ', cs_loss_clamp/count)
+ print('L1 loss: ', L1_loss/count)
+ print('mean_l1_dist: ', mean_l1_dist/count)
+ print('precision: ', precision/count)
+ print('recall: ', recall/count)
+ print('count: ', count)
+ print('------------------------------------')
+
+def evaluate_skeleton(joints,joints_gt,conns,conns_gt,vox_list,fs_threshold=0.2):
+
+ """
+ joints: list of predicted joints: tensor of shape (n,joints_num,3)
+ joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
+ conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
+ conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
+ vox_list: list of voxel: (n,88,88,88)
+ """
+
+ data_count = 0
+ chamfer_score = 0
+ j2b_chamfer_joint = 0
+ bone2bone_dist_joint = 0
+ edit_distance_joint = 0
+ joint_IoU_total = 0
+ joint_precision_total = 0
+ joint_recall_total = 0
+
+ for i in tqdm(range(len(joints))):
+ joint_predict = joints[i].cpu().numpy()
+ joint_gt = joints_gt[i].cpu().numpy()
+ conn_predict = conns[i].cpu().numpy()
+ conn_gt = conns_gt[i].cpu().numpy()
+ vox = vox_list[i]
+
+ # add shape diameter after we have vertex and faces
+ # shape_diameter = get_shape_diameter(mesh, points, parent_index[:,0])
+
+ dist_matrix = np.sqrt(np.sum((joint_predict[np.newaxis, ...] - joint_gt[:, np.newaxis, :]) ** 2, axis=2))
+ row_ind, col_ind = linear_sum_assignment(dist_matrix)
+ # fs_threshold = shape_diameter[row_ind]
+ joint_IoU = 2 * np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / (len(joint_predict) + len(joint_gt))
+ joint_IoU_total += joint_IoU
+ joint_precision = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_predict)
+ joint_precision_total += joint_precision
+ joint_recall = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_gt)
+ joint_recall_total += joint_recall
+
+ pred_skel_joint,parent_matrix = get_skel(joint_predict,conn_predict,vox=vox)
+ gt_skel, parent_matrix = get_skel(joint_gt,conn_gt,vox=vox)
+ chamfer_score += chamfer_dist(joint_predict, joint_gt)
+ j2b_chamfer_joint += joint2bone_chamfer_dist(pred_skel_joint, gt_skel)
+ bone2bone_dist_joint += bone2bone_chamfer_dist(pred_skel_joint, gt_skel)
+ edit_distance_joint += edit_dist(pred_skel_joint, gt_skel)
+ data_count+=1
+
+ print('------------------------------------')
+ print('Evaluation results for skeleton:')
+ print('chamfer_score:', chamfer_score/data_count)
+ print('j2b_chamfer_joint:', j2b_chamfer_joint/data_count)
+ print('bone2bone_dist_joint:', bone2bone_dist_joint/data_count)
+ print('joint_IoU:', joint_IoU_total/data_count)
+ print('joint_precision:', joint_precision_total/data_count)
+ print('joint_recall:', joint_recall_total/data_count)
+ print('------------------------------------')
\ No newline at end of file
diff --git a/Anymate/utils/loss_utils.py b/Anymate/utils/loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acff96ada6c34426405f6b3e19ce080947d30db
--- /dev/null
+++ b/Anymate/utils/loss_utils.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+def chamfer_distance_with_average(p1, p2):
+
+    '''
+    Calculate the Chamfer distance between two point sets.
+    :param p1: tensor of size [1, N, D]
+    :param p2: tensor of size [1, M, D]
+    :return: average of the two directional mean nearest-neighbour distances
+    '''
+
+ assert p1.size(0) == 1 and p2.size(0) == 1
+ assert p1.size(2) == p2.size(2)
+ p1 = p1.repeat(p2.size(1), 1, 1)
+ p1 = p1.transpose(0, 1)
+ p2 = p2.repeat(p1.size(0), 1, 1)
+ dist = torch.add(p1, torch.neg(p2))
+ dist_norm = torch.norm(dist, 2, dim=2)
+ dist1 = torch.min(dist_norm, dim=1)[0]
+ dist2 = torch.min(dist_norm, dim=0)[0]
+ loss = 0.5 * ((torch.mean(dist1)) + (torch.mean(dist2)))
+ return loss
+
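+# Illustrative usage of chamfer_distance_with_average (hypothetical point sets; not part of the original code):
+#   p1 = torch.rand(1, 64, 3)
+#   p2 = torch.rand(1, 96, 3)
+#   loss = chamfer_distance_with_average(p1, p2)  # scalar: mean of the two directional nearest-neighbour distances
+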
+def cross_entropy_with_probs_batch(input, target, weight=None, reduction="mean"): # tested, same as nn.CrossEntropyLoss at dim=1, CE can be negative
+ # input_logsoftmax = F.log_softmax(input, dim=2)
+ input_logsoftmax = torch.log(input+1e-6)
+ cum_losses = -target * input_logsoftmax
+ if weight is not None:
+ cum_losses = cum_losses * weight.unsqueeze(1) # Broadcasting the weight
+
+ if reduction == "none":
+ return cum_losses
+ elif reduction == "mean":
+ return cum_losses.sum(dim=2).mean(dim=1).mean(dim=0)
+ elif reduction == "sum":
+ return cum_losses.sum(dim=2).sum(dim=1).mean(dim=0)
+ else:
+ raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']")
+
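+# Illustrative usage of cross_entropy_with_probs_batch (hypothetical probabilities; not part of the original code):
+#   pred = torch.softmax(torch.rand(2, 10, 5), dim=-1)    # predicted probabilities
+#   target = torch.softmax(torch.rand(2, 10, 5), dim=-1)  # soft targets
+#   loss = cross_entropy_with_probs_batch(pred, target)   # scalar (sum over classes, mean over joints and batch)
+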
+def cos_loss(input, target):
+ # input = F.softmax(input, dim=-1)
+ cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
+ similarity = cos(input, target)
+ loss = 1 - similarity.mean()
+ return loss
+
+def cos_loss_clamp(input, target):
+ # input = F.softmax(input, dim=-1)*(1 + 2*0.001) - 0.001
+ input = input*(1 + 2*0.001) - 0.001
+ input = torch.clamp(input, 0, 1)
+ cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
+ similarity = cos(input, target)
+ loss = 1 - similarity.mean()
+ return loss
\ No newline at end of file
diff --git a/Anymate/utils/render_utils.py b/Anymate/utils/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db98616131cb43dcc10c3762fbb9ee85533e0b2e
--- /dev/null
+++ b/Anymate/utils/render_utils.py
@@ -0,0 +1,1169 @@
+import bpy
+import numpy as np
+from mathutils import Vector, Matrix
+from tqdm import tqdm
+import glob
+import os
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+cmap = plt.get_cmap('viridis')
+import torchvision.io as io
+import cv2
+import trimesh
+
+def get_data(ids, root, animate=False, shift_rig=True, id2=None, rignet=False):
+    dataset = torch.load('/data2/aod/testJointDataSet_9.pt')
+ joints = []
+ conns = []
+ skins = []
+
+ for id in ids:
+ if id2 is None:
+ for data in dataset:
+ if id in data['name']:
+ print(data['name'])
+ break
+ else:
+ for data in dataset:
+ if id2 in data['name']:
+ print(data['name'])
+ break
+
+ joint = torch.tensor(torch.load(root + '/joints/' + id + '.pt')).cpu()
+ if shift_rig and id2 is None:
+ y_max = data['points_cloud'][:,1].max()
+ joint = joint/2 + torch.tensor([0,y_max/2,0])
+ temp = joint[:, 1].clone()
+ joint[:, 1] = -joint[:, 2]
+ joint[:, 2] = temp
+
+ conn = torch.tensor(torch.load(root + '/connectivity/' + id + '.pt')).long()
+ if not animate:
+ skin = torch.load(root + '/skinning/' + id + '.pt')
+ if rignet:
+ skins.append(skin[0])
+ elif id2 is None:
+ skins.append(skin[0].softmax(dim=-1).cpu().numpy())
+ else:
+ skins.append(skin)
+
+ joints.append(joint)
+ conns.append(conn)
+
+ return joints, conns, skins
+
+def index_to_sparse(index, weight, shape):
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
+
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
+
+ row_indices = np.expand_dims(row_indices, axis=-1)
+ col_indices = np.expand_dims(col_indices, axis=-1)
+
+ sparse_matrix[row_indices, col_indices, index] = weight
+
+
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
+
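+# Illustrative usage of index_to_sparse (hypothetical shapes; not part of the original code):
+#   index = torch.randint(0, 24, (1, 8192, 4))            # top-4 bone indices per point
+#   weight = torch.rand(1, 8192, 4)                        # matching skinning weights
+#   dense = index_to_sparse(index, weight, [1, 8192, 24])  # -> dense (1, 8192, 24) skinning matrix
+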
+def get_gt(ids, root):
+    dataset = torch.load('/data2/aod/testJointDataSet_9.pt')
+ joints = []
+ conns = []
+ skins = []
+
+ for id in ids:
+ for data in dataset:
+ if id in data['name']:
+ print(data['name'])
+ break
+
+ joint = data['joints_matrix'][:data['joints_num'], :3]
+ y_max = data['points_cloud'][:,1].max()
+ joint = joint/2 + torch.tensor([0,y_max/2,0])
+ temp = joint[:, 1].clone()
+ joint[:, 1] = -joint[:, 2]
+ joint[:, 2] = temp
+
+ conn = data['parent_index'][:data['joints_num']].long().unsqueeze(1)
+
+ skin = index_to_sparse(data['skin_index'].unsqueeze(0), data['skin_weight'].unsqueeze(0), [1, 8192, data['joints_num']])
+
+ joints.append(joint)
+ conns.append(conn)
+ skins.append(skin[0])
+
+ return joints, conns, skins
+
+def empty():
+ bpy.ops.wm.read_homefile(use_empty=True)
+ # Delete all mesh objects from the scene
+ # for obj in bpy.context.scene.objects:
+ # bpy.data.objects.remove(obj, do_unlink=True)
+
+def add_mesh(filepath, co=None, tex=False, color=(0.5, 0.5, 0.5, 1)):
+ bpy.ops.wm.obj_import(filepath=filepath)
+ obj = bpy.context.object
+
+ if not tex:
+ # give the mesh a material
+ bpy.context.view_layer.objects.active = obj
+ bpy.ops.object.shade_smooth()
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.mesh.select_all(action='SELECT')
+ bpy.ops.mesh.normals_make_consistent(inside=False)
+ bpy.ops.object.mode_set(mode='OBJECT')
+ mat = bpy.data.materials.new(name='mat')
+ obj.data.materials.clear()
+ obj.data.materials.append(mat)
+ mat.use_nodes = True
+ mat.node_tree.nodes.clear()
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.8
+ # mat.node_tree.nodes['Principled BSDF'].inputs['Specular'].default_value = 0.5
+ # mat.node_tree.nodes['Principled BSDF'].inputs['Metallic'].default_value = 0.5
+ mat.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
+ if co is not None:
+ obj.parent = co
+
+def create_sphere(location, size=0.01, color=(1.0, 0.0, 0.0, 1.0), reduced=False):
+ if reduced:
+ bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location, segments=8, ring_count=4)
+ else:
+ bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location)
+ sphere = bpy.context.active_object
+
+ material_name = f"ColorMaterial_{color}"
+ material = bpy.data.materials.get(material_name)
+
+ if not material:
+ material = bpy.data.materials.new(name=material_name)
+ material.use_nodes = True
+ material.node_tree.nodes.clear()
+ bsdf = material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
+ output = material.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
+
+ sphere.data.materials.append(material)
+
+ return sphere
+
+def add_co(location=(0,0,0), rotation=(0,0,0), scale=(1,1,1)):
+ co = bpy.data.objects.new("CoordinateSystem", None)
+ bpy.context.collection.objects.link(co)
+ bpy.context.view_layer.objects.active = co
+ co.empty_display_size = 0.1
+ co.empty_display_type = 'ARROWS'
+ co.location = location
+ co.rotation_euler = rotation
+ co.scale = scale
+
+ return co
+
+def add_joint(joints_matrix, co=None):
+
+ for i, joint in enumerate(joints_matrix):
+ sphere = create_sphere((joint[0], joint[1], joint[2]), size=0.01)
+ if co is not None:
+ sphere.parent = co
+
+def create_blue_cone(base_point, apex_point, radius=0.1):
+ # Calculate the radius and length of the cone
+ direction = apex_point - base_point
+ length = direction.length
+
+ # Create cone mesh
+ bpy.ops.mesh.primitive_cone_add(vertices=32, radius1=radius, depth=length, location=(base_point + direction * 0.5))
+ cone = bpy.context.active_object
+
+ # Create or get the blue material
+ blue_material = bpy.data.materials.get("BlueMaterial")
+ if not blue_material:
+ blue_material = bpy.data.materials.new(name="BlueMaterial")
+ blue_material.use_nodes = True
+ blue_material.node_tree.nodes.clear()
+ bsdf = blue_material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
+ output = blue_material.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ blue_material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ blue_material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = (0.0, 0.0, 1.0, 1.0)
+
+ cone.data.materials.append(blue_material)
+
+ # Set the cone's orientation
+ cone.rotation_euler = direction.to_track_quat('Z', 'Y').to_euler()
+
+ return cone
+
+def add_conn(con_index, joints_matrix, co=None):
+ for i, parent in enumerate(con_index):
+ parent = parent.item()
+ if parent != i:
+ parent_co = Vector((joints_matrix[parent][0], joints_matrix[parent][1], joints_matrix[parent][2]))
+ position = Vector((joints_matrix[i][0], joints_matrix[i][1], joints_matrix[i][2]))
+ cone = create_blue_cone(parent_co, position, radius=0.008)
+ if co is not None:
+ cone.parent = co
+
+def merge_images(img1, img2, output_path, alpha=1):
+ image_mesh = Image.open(img1)
+ image_rig = Image.open(img2)
+
+ if alpha == 1:
+ image_mesh.paste(image_rig, (0, 0), image_rig)
+ image_mesh.save(output_path)
+ return
+
+ data = image_rig.getdata()
+ data2 = image_mesh.getdata()
+ new_data = []
+ for item, item2 in zip(data, data2):
+ if item[3] == 0:
+ new_data.append(item2)
+ else:
+ new_data.append((int(item[0]*alpha + item2[0]*(1-alpha)), int(item[1]*alpha + item2[1]*(1-alpha)), int(item[2]*alpha + item2[2]*(1-alpha)), 255))
+ image_mesh.putdata(new_data)
+
+ # image_mesh.paste(image_rig, (0, 0), image_rig)
+
+ image_mesh.save(output_path)
+
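+# Illustrative usage of merge_images (hypothetical RGBA render files; not part of the original code):
+#   merge_images('mesh_render.png', 'rig_render.png', 'combined.png', alpha=0.8)
+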
+def merge_videos(video1, video2, output_path):
+
+ # overlap two videos together, video1 is the background, video2 is the foreground
+ # os.system(f'ffmpeg -i {video1} -i {video2} -filter_complex "[0:v][1:v] overlay=0:0:enable=\'between(t,0,60)\'" -pix_fmt yuv420p -c:a copy {output_path}')
+
+ frames_path_1 = glob.glob(video1 + '*.png')
+ total_frames = len(frames_path_1)
+ combined_frames = []
+ for i in range(total_frames):
+ frame1 = Image.open(f'{video1}{i:04d}.png')
+ frame2 = Image.open(f'{video2}{i:04d}.png')
+ frame1.paste(frame2, (0, 0), frame2)
+ combined_frames.append(frame1)
+
+ # paste the combined frames on a pure white background
+ combined_frames_white = []
+ for frame in combined_frames:
+ white = Image.new('RGB', frame.size, (255, 255, 255))
+ white.paste(frame, (0, 0), frame)
+ combined_frames_white.append(white)
+
+ combined_frames=combined_frames_white
+
+ combined_videos = torch.stack([torch.tensor(np.array(frame)) for frame in combined_frames])[..., :3]
+
+ # write the video with high quality
+ # io.write_video(output_path, combined_videos, 24)
+ io.write_video(output_path, combined_videos, 24, video_codec='libx264', options={'crf': '18'})
+
+ # comvert the frames to mp4 video
+
+ # video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'H264'), 30, (frame1.size[0], frame1.size[1]))
+ # for frame in combined_frames:
+ # video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+ # video.release()
+
+ # video_1, audio_1, fps_1 = io.read_video(video1, pts_unit="sec")
+ # video_2, audio_2, fps_2 = io.read_video(video2, pts_unit="sec")
+ # non_zero = video_2.sum(dim=-1) != 0
+ # non_zero = torch.stack([non_zero, non_zero, non_zero], dim=-1)
+ # video_1[non_zero] = video_2[non_zero]
+ # io.write_video(output_path, video_1, int(fps_1['video_fps']))
+
+def add_skin(filepath, skin, bone_index, co=None, pc=None):
+ bpy.ops.wm.obj_import(filepath=filepath)
+ obj = bpy.context.object
+
+ bpy.context.view_layer.objects.active = obj
+ bpy.ops.object.shade_smooth()
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.mesh.select_all(action='SELECT')
+ bpy.ops.mesh.normals_make_consistent(inside=False)
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ if co is not None:
+ obj.parent = co
+
+ if pc is not None:
+ skin = np.array(skin)
+ pc = pc[:, :3].numpy()
+ y_max = pc[:, 1].max()
+ pc = pc + np.array([0, y_max, 0])
+ pc = pc / 2
+ new_skin = np.zeros((len(obj.data.vertices), skin.shape[1]))
+ for i, v in enumerate(obj.data.vertices):
+ v_co = np.array(v.co)
+
+ dist = np.linalg.norm(pc - v_co, axis=1)
+ # min_idx = np.argmin(dist)
+ # sort, and then get top 3 index
+ min_idx_list = np.argsort(dist)[:3]
+
+ for min_idx in min_idx_list:
+ # get inverse distance weight
+ interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
+ new_skin[i] = new_skin[i] + interpolate_weight * skin[min_idx]
+
+ skin = new_skin
+
+ color_list = skin
+
+ color_list = color_list[:,bone_index]
+
+ vertex_colors = obj.data.vertex_colors.new()
+
+ for poly in obj.data.polygons:
+ for loop_index in poly.loop_indices:
+
+ vertex_index = obj.data.loops[loop_index].vertex_index
+ # Get the weight for the vertex
+ weight = color_list[vertex_index]
+
+ color = cmap(weight)
+
+ # Assign the weight to the vertex color (RGBA)
+ vertex_colors.data[loop_index].color = color # Use the weight for RGB
+
+ # let bsdf use vertex color and then output to surface
+ mat = bpy.data.materials.new(name='mat')
+ # delete all material of obj
+ obj.data.materials.clear()
+ obj.data.materials.append(mat)
+ mat.use_nodes = True
+ mat.node_tree.nodes.clear()
+ vertex_color = mat.node_tree.nodes.new('ShaderNodeVertexColor')
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ mat.node_tree.links.new(vertex_color.outputs['Color'], bsdf.inputs['Base Color'])
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.5
+
+
+
+def add_pc(points):
+ base_sphere = create_sphere((points[0][0], points[0][1], points[0][2]), size=0.003, color=cmap(0), reduced=True)
+ # copy the base sphere to create the rest of the spheres
+ for i in tqdm(range(1, points.shape[0])):
+ new_sphere = base_sphere.copy()
+ new_sphere.location = (points[i][0], points[i][1], points[i][2])
+ bpy.context.collection.objects.link(new_sphere)
+
+def add_floor(back=False):
+ # create a plane as floor
+ bpy.ops.mesh.primitive_plane_add(size=50, enter_editmode=False, align='WORLD', location=(0, 20, 0))
+ floor = bpy.context.object
+ floor.name = 'floor'
+ # set white material for floor
+ mat = bpy.data.materials.new(name='floor_mat')
+ floor.data.materials.append(mat)
+ mat.use_nodes = True
+ mat.node_tree.nodes.clear()
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
+
+ if back:
+ # create a plane as background
+ bpy.ops.mesh.primitive_plane_add(size=30, enter_editmode=False, align='WORLD', location=(0, 15, 0), rotation=(-0.5*np.pi, 0, 0))
+ background = bpy.context.object
+ background.name = 'background'
+ # set white material for background
+ mat = bpy.data.materials.new(name='background_mat')
+ background.data.materials.append(mat)
+ mat.use_nodes = True
+ mat.node_tree.nodes.clear()
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
+ mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
+
+def setup_render():
+ # color management
+ bpy.context.scene.view_settings.view_transform = 'Standard'
+
+ # set the render engine to Cycles
+ bpy.context.scene.render.engine = 'CYCLES'
+ # enable cuda
+ bpy.context.preferences.addons['cycles'].preferences.get_devices()
+ bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
+ bpy.context.scene.cycles.device = 'GPU'
+
+ # set render background to transparent
+ bpy.context.scene.render.film_transparent = True
+
+def render(output_path, shadow=True, shading=True, quick=False):
+
+ if shadow:
+ add_floor()
+
+ if shading:
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (0.1*np.pi, 0, 0)
+ # set angle
+ light.data.angle = 0.08*np.pi
+
+ else:
+ # global illumination by create world light
+ world = bpy.data.worlds.new('World')
+ bpy.context.scene.world = world
+ world.use_nodes = True
+ world_light = world.node_tree.nodes['Background']
+ world_light.inputs['Strength'].default_value = 1
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((2, -1.5, 2))
+ look_at = Vector((0, 0, 0.36))
+ # compute the rotation
+ camera.rotation_mode = 'QUATERNION'
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
+ # set size
+ camera.data.sensor_width = 26
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 2048
+ bpy.context.scene.render.resolution_y = 2048
+
+ setup_render()
+
+ if quick:
+ # reduce the number of samples
+ bpy.context.scene.cycles.samples = 128
+ bpy.context.scene.cycles.preview_samples = 128
+ bpy.context.scene.cycles.max_bounces = 1
+ bpy.context.scene.cycles.min_bounces = 1
+ bpy.context.scene.cycles.diffuse_bounces = 1
+ bpy.context.scene.cycles.glossy_bounces = 1
+ else:
+ bpy.context.scene.cycles.samples = 1024
+ bpy.context.scene.cycles.preview_samples = 1024
+ bpy.context.scene.cycles.max_bounces = 4
+ bpy.context.scene.cycles.min_bounces = 4
+ bpy.context.scene.cycles.diffuse_bounces = 4
+ bpy.context.scene.cycles.glossy_bounces = 4
+
+ # output path
+ # output_path = '/home/ydengbd/objaverse/test.png'
+ bpy.context.scene.render.filepath = output_path
+ bpy.ops.render.render(write_still=True)
+
+def render_spin(output_path, co, shadow=True, shading=True, quick=False):
+ # create a new coordinate system at the origin
+ new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
+ # set the object to be the child of the new coordinate system
+ co.parent = new_co
+
+ # add spin animation to the new coordinate system
+ new_co.rotation_mode = 'XYZ'
+ new_co.rotation_euler = (0, 0, 0)
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=0)
+ new_co.rotation_euler = (0, 0, 2*np.pi)
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=60)
+
+ if shadow:
+ add_floor()
+
+ if shading:
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (0.1*np.pi, 0, 0)
+ # set angle
+ light.data.angle = 0.08*np.pi
+
+ else:
+ # global illumination by create world light
+ world = bpy.data.worlds.new('World')
+ bpy.context.scene.world = world
+ world.use_nodes = True
+ world_light = world.node_tree.nodes['Background']
+ world_light.inputs['Strength'].default_value = 1
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((2, -1.5, 2))
+ look_at = Vector((0, 0, 0.36))
+ # compute the rotation
+ camera.rotation_mode = 'QUATERNION'
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
+ # set size
+ camera.data.sensor_width = 26
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+
+ # render the animation
+ bpy.context.scene.frame_start = 0
+ bpy.context.scene.frame_end = 60
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 1024
+ bpy.context.scene.render.resolution_y = 1024
+
+ setup_render()
+
+ if quick:
+ # reduce the number of samples
+ bpy.context.scene.cycles.samples = 128
+ bpy.context.scene.cycles.preview_samples = 128
+ bpy.context.scene.cycles.max_bounces = 1
+ bpy.context.scene.cycles.min_bounces = 1
+ bpy.context.scene.cycles.diffuse_bounces = 1
+ bpy.context.scene.cycles.glossy_bounces = 1
+ else:
+ bpy.context.scene.cycles.samples = 512
+ bpy.context.scene.cycles.preview_samples = 512
+ bpy.context.scene.cycles.max_bounces = 4
+ bpy.context.scene.cycles.min_bounces = 4
+ bpy.context.scene.cycles.diffuse_bounces = 4
+ bpy.context.scene.cycles.glossy_bounces = 4
+
+ # output path
+ bpy.context.scene.render.filepath = output_path
+ if output_path.endswith('.mp4'):
+ # render a mp4 video
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
+
+ bpy.ops.render.render(animation=True, write_still=True)
+
+def setup_anim(armature, arti):
+ # enter pose mode
+ print('Arti shape', arti.shape)
+ bpy.ops.object.mode_set(mode='POSE')
+ print('total bones', len(armature.pose.bones))
+ for i, pose_bone in enumerate(armature.pose.bones):
+ pose_bone.rotation_mode = 'XYZ'
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
+
+ pose_bone.rotation_euler = arti[i]
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
+
+ pose_bone.rotation_euler = Vector((0, 0, 0))
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+def render_anim(output_path, armature, arti, quick=False):
+ # enter pose mode
+ setup_anim(armature, arti)
+
+ # save blend file
+ # bpy.ops.wm.save_as_mainfile(filepath='/data2/ydengbd/objaverse/test.blend')
+
+ add_floor()
+
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
+ # set angle
+ light.data.angle = 12/180*np.pi
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((0, -3, 1.3))
+ camera.rotation_euler = Vector((1.309, 0, 0))
+ # set size
+ camera.data.sensor_width = 36
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+ # render the animation
+ bpy.context.scene.frame_start = 0
+ bpy.context.scene.frame_end = 60
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 1920
+ bpy.context.scene.render.resolution_y = 1080
+
+ setup_render()
+
+ if quick:
+ # reduce the number of samples
+ bpy.context.scene.cycles.samples = 128
+ bpy.context.scene.cycles.preview_samples = 128
+ bpy.context.scene.cycles.max_bounces = 1
+ bpy.context.scene.cycles.min_bounces = 1
+ bpy.context.scene.cycles.diffuse_bounces = 1
+ bpy.context.scene.cycles.glossy_bounces = 1
+ else:
+ bpy.context.scene.cycles.samples = 1024
+ bpy.context.scene.cycles.preview_samples = 1024
+ bpy.context.scene.cycles.max_bounces = 4
+ bpy.context.scene.cycles.min_bounces = 4
+ bpy.context.scene.cycles.diffuse_bounces = 4
+ bpy.context.scene.cycles.glossy_bounces = 4
+
+ # output path
+ bpy.context.scene.render.filepath = output_path
+ if output_path.endswith('.mp4'):
+ # render a mp4 video
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
+
+ bpy.ops.render.render(animation=True, write_still=True)
+
+
+def render_animspin(output_path, co, armature, arti, shadow=True, shading=True, quick=False):
+ # enter pose mode
+ print('Arti shape', arti.shape)
+ bpy.ops.object.mode_set(mode='POSE')
+ print('total bones', len(armature.pose.bones))
+ for i, pose_bone in enumerate(armature.pose.bones):
+ pose_bone.rotation_mode = 'XYZ'
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
+
+ pose_bone.rotation_euler = arti[i]
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
+
+ pose_bone.rotation_euler = Vector((0, 0, 0))
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
+
+ pose_bone.rotation_euler = arti[i]
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=90)
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=150)
+
+ pose_bone.rotation_euler = Vector((0, 0, 0))
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=180)
+ bpy.ops.object.mode_set(mode='OBJECT')
+
+ # create a new coordinate system at the origin
+ new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
+ # set the object to be the child of the new coordinate system
+ co.parent = new_co
+
+ # add spin animation to the new coordinate system
+ new_co.rotation_mode = 'XYZ'
+ new_co.rotation_euler = (0, 0, 0)
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=90)
+ new_co.rotation_euler = (0, 0, 2*np.pi)
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=150)
+
+ if shadow:
+ add_floor()
+
+ if shading:
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (0.1*np.pi, 0, 0)
+ # set angle
+ light.data.angle = 0.08*np.pi
+
+ else:
+ # global illumination by create world light
+ world = bpy.data.worlds.new('World')
+ bpy.context.scene.world = world
+ world.use_nodes = True
+ world_light = world.node_tree.nodes['Background']
+ world_light.inputs['Strength'].default_value = 1
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((2, -1.5, 2))
+ look_at = Vector((0, 0, 0.36))
+ # compute the rotation
+ camera.rotation_mode = 'QUATERNION'
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
+ # set size
+ camera.data.sensor_width = 26
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+
+ # render the animation
+ bpy.context.scene.frame_start = 0
+ bpy.context.scene.frame_end = 180
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 1024
+ bpy.context.scene.render.resolution_y = 1024
+
+ setup_render()
+
+ if quick:
+ # reduce the number of samples
+ bpy.context.scene.cycles.samples = 128
+ bpy.context.scene.cycles.preview_samples = 128
+ bpy.context.scene.cycles.max_bounces = 1
+ bpy.context.scene.cycles.min_bounces = 1
+ bpy.context.scene.cycles.diffuse_bounces = 1
+ bpy.context.scene.cycles.glossy_bounces = 1
+ else:
+ bpy.context.scene.cycles.samples = 512
+ bpy.context.scene.cycles.preview_samples = 512
+ bpy.context.scene.cycles.max_bounces = 4
+ bpy.context.scene.cycles.min_bounces = 4
+ bpy.context.scene.cycles.diffuse_bounces = 4
+ bpy.context.scene.cycles.glossy_bounces = 4
+
+ # output path
+ bpy.context.scene.render.filepath = output_path
+ if output_path.endswith('.mp4'):
+ # render a mp4 video
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
+
+ bpy.ops.render.render(animation=True, write_still=True)
+
+def render_scene(output_path, shadow=True):
+
+ if shadow:
+ add_floor()
+
+
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
+ # set angle
+ light.data.angle = 12/180*np.pi
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((0, -10, 5))
+ camera.rotation_euler = Vector((1.22, 0, 0))
+ # set size
+ camera.data.sensor_width = 26
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 1920
+ bpy.context.scene.render.resolution_y = 1080
+
+ setup_render()
+
+
+
+ # output path
+ # output_path = '/home/ydengbd/objaverse/test.png'
+ bpy.context.scene.render.filepath = output_path
+ bpy.ops.render.render(write_still=True)
+
+
+def render_teaser(output_path, shadow=True, quick=False):
+
+ if shadow:
+ add_floor(back=True)
+
+ # create a sun light
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
+ light = bpy.context.object
+ light.data.energy = 5
+ # angle pointing to the origin
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
+ # set angle
+ light.data.angle = 12/180*np.pi
+
+ # create a camera
+ cam = bpy.data.cameras.new("Camera")
+ cam_ob = bpy.data.objects.new("Camera", cam)
+ camera = bpy.data.objects['Camera']
+ bpy.context.scene.collection.objects.link(camera)
+ camera.location = Vector((0, -3, 1.3))
+ camera.rotation_euler = Vector((80/180*np.pi, 0, 0))
+ # set size
+ camera.data.sensor_width = 48
+ # set the camera to be active
+ bpy.context.scene.camera = camera
+
+ # render the animation
+ bpy.context.scene.frame_start = 0
+ bpy.context.scene.frame_end = 60
+
+ # make the rendered image square
+ bpy.context.scene.render.resolution_x = 2400
+ bpy.context.scene.render.resolution_y = 1080
+
+ setup_render()
+
+ if quick:
+ # reduce the number of samples
+ bpy.context.scene.cycles.samples = 128
+ bpy.context.scene.cycles.preview_samples = 128
+ bpy.context.scene.cycles.max_bounces = 1
+ bpy.context.scene.cycles.min_bounces = 1
+ bpy.context.scene.cycles.diffuse_bounces = 1
+ bpy.context.scene.cycles.glossy_bounces = 1
+ else:
+ bpy.context.scene.cycles.samples = 1024
+ bpy.context.scene.cycles.preview_samples = 1024
+ bpy.context.scene.cycles.max_bounces = 4
+ bpy.context.scene.cycles.min_bounces = 4
+ bpy.context.scene.cycles.diffuse_bounces = 4
+ bpy.context.scene.cycles.glossy_bounces = 4
+
+ # output path
+ bpy.context.scene.render.filepath = output_path
+ if output_path.endswith('.mp4'):
+ # render a mp4 video
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
+
+ bpy.ops.render.render(animation=True, write_still=True)
+
+def setup_armature(path, tex=False, save=True):
+ joints_matrix = torch.load(os.path.join(path, 'joints.pt'))
+ connectivity = torch.load(os.path.join(path, 'conns.pt'))
+ skinning_weights = torch.load(os.path.join(path, 'skins.pt'))
+ obj_file_path = os.path.join(path, 'object.obj')
+
+ # bpy.ops.wm.obj_import(filepath=obj_file_path)
+ add_mesh(obj_file_path, tex=tex)
+ mesh_object = bpy.context.selected_objects[0]
+
+ # pack textures
+ bpy.ops.file.pack_all()
+
+    temp = joints_matrix[:, 1].clone()
+ joints_matrix[:, 1] = -joints_matrix[:, 2]
+ joints_matrix[:, 2] = temp
+
+ bpy.ops.object.armature_add()
+ armature_obj = bpy.context.object
+
+
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.armature.select_all(action='SELECT')
+ bpy.ops.armature.delete()
+
+ world_matrix = Matrix([[1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ armature_obj.matrix_world = world_matrix
+
+ bone_dict = {}
+
+ i_name = 0
+
+ for i in range(len(joints_matrix)):
+
+ if connectivity[i] == i:
+ continue
+ bone_name = str(i_name)
+ bone = armature_obj.data.edit_bones.new(bone_name)
+ bone.head = joints_matrix[connectivity[i]].cpu().numpy()
+ bone.tail = joints_matrix[i].cpu().numpy()
+ bone_dict[bone_name] = bone
+ i_name += 1
+
+ for bone_name, bone in bone_dict.items():
+ # Find parent bone by checking if current bone's head matches any other bone's tail
+ for other_bone_name, other_bone in bone_dict.items():
+ if other_bone != bone and bone.head == other_bone.tail:
+ bone.parent = other_bone
+ break
+
+ assert i_name == skinning_weights.shape[1]
+
+ for i, skinning_weight in enumerate(skinning_weights):
+ # print("skinning_weight", skinning_weight)
+ vertex_index = i
+ for j,weight in enumerate(skinning_weight):
+ bone_name = str(j)
+ bone_weight = float(weight)
+
+ vertex_group_name = f"{bone_name}"
+ vertex_group = mesh_object.vertex_groups.get(vertex_group_name)
+ if vertex_group is None:
+ vertex_group = mesh_object.vertex_groups.new(name=vertex_group_name)
+ vertex_group.add([vertex_index], bone_weight, 'ADD')
+
+ # for obj in bpy.context.scene.objects:
+ # if obj.type == 'MESH':
+ modifier = mesh_object.modifiers.new(name="Armature", type='ARMATURE')
+ modifier.object = armature_obj
+ modifier.use_vertex_groups = True
+ print("Armature modifier added to mesh:", mesh_object.name)
+
+ bpy.ops.object.mode_set(mode='OBJECT')
+ if save:
+ bpy.ops.wm.save_as_mainfile(filepath= os.path.join(path, 'blender_output.blend'))
+
+ return armature_obj
+
+def reload_tensor_skinning(data, bone_name_list):
+
+ # with open(json_file, "r") as f:
+ # skinning_data = json.load(f)
+
+ armature_obj = bpy.data.objects.get("Armature")
+ if not armature_obj:
+ print("Error: Armature object 'Armature' not found.")
+ return
+
+    # parent all mesh objects under the armature object
+ count = 0
+ for obj in bpy.context.scene.objects:
+ if obj.type == 'MESH':
+ obj.parent = armature_obj
+ count += 1
+
+ print("total mesh count:", count)
+
+ for obj in bpy.context.scene.objects:
+ vertex_index = 0
+ if obj.type == 'MESH':
+ # mesh_name = obj.name
+ # if mesh_name in skinning_data:
+ # skinning_info = skinning_data[mesh_name]
+ # if "weight" in skinning_info:
+ # print("Applying skinning data for mesh:", mesh_name)
+ # vertex_index = 0
+ # for vertex_weight in skinning_info["weight"]:
+ # for bone_name, weight_value in vertex_weight.items():
+ # vertex_group = obj.vertex_groups.get(bone_name)
+ # if vertex_group is None:
+ # vertex_group = obj.vertex_groups.new(name=bone_name)
+ # print("Vertex group created:", bone_name)
+ # vertex_group.add([vertex_index], weight_value, 'REPLACE')
+ # vertex_index += 1
+ # else:
+ # print("No skinning data found for mesh:", mesh_name)
+
+ for i, v in enumerate(obj.data.vertices):
+ v_co = np.array(v.co)
+ pc = data['pc'][:, :3].numpy()
+ y_max = pc[:, 1].max()
+ pc = pc + np.array([0, y_max, 0])
+ pc = pc / 2
+ dist = np.linalg.norm(pc - v_co, axis=1)
+ # min_idx = np.argmin(dist)
+ # sort, and then get top 3 index
+ min_idx_list = np.argsort(dist)[:3]
+
+ for min_idx in min_idx_list:
+ # get inverse distance weight
+ interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
+
+ for idx, j in enumerate(data['skins_index'][min_idx]):
+ if j == -1:
+ break
+ bone_name = bone_name_list[j]
+ vertex_group = obj.vertex_groups.get(str(int(bone_name)))
+ if vertex_group is None:
+ vertex_group = obj.vertex_groups.new(name=str(int(bone_name)))
+ print("Vertex group created:", bone_name)
+
+ vertex_group.add([i], interpolate_weight * data['skins_weight'][min_idx][idx], 'ADD')
+
+
+ for obj in bpy.context.scene.objects:
+ if obj.type == 'MESH':
+ modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
+ modifier.object = armature_obj
+ modifier.use_vertex_groups = True
+ print("Armature modifier added to mesh:", obj.name)
+
+def reload_tensor(data, root='data', save=True):
+ joints_matrix = data['joints'].clone()
+ connectivity = data['conns']
+ obj_file_path = os.path.join(root, data['name'], 'object.obj')
+
+ # bpy.ops.wm.obj_import(filepath=obj_file_path)
+ add_mesh(obj_file_path)
+ mesh_object = bpy.context.selected_objects[0]
+
+ # pack textures
+ bpy.ops.file.pack_all()
+
+ y_max = data['pc'][:, 1].max()
+ joints_matrix = joints_matrix + torch.tensor([0, y_max, 0])
+ joints_matrix = joints_matrix / 2
+
+ temp = joints_matrix[:, 1].clone()
+ joints_matrix[:, 1] = -joints_matrix[:, 2]
+ joints_matrix[:, 2] = temp
+
+ bpy.ops.object.armature_add()
+ armature_obj = bpy.context.object
+
+
+ bpy.ops.object.mode_set(mode='EDIT')
+ bpy.ops.armature.select_all(action='SELECT')
+ bpy.ops.armature.delete()
+
+ world_matrix = Matrix([[1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ armature_obj.matrix_world = world_matrix
+
+ bone_dict = {}
+ bone_name_list = np.zeros(data['bones_num'])
+ i_name = 0
+
+ for i in range(len(joints_matrix)):
+
+ if connectivity[i] == i:
+ continue
+ bone_name = str(i_name)
+ bone = armature_obj.data.edit_bones.new(bone_name)
+ bone.head = joints_matrix[connectivity[i]].cpu().numpy()
+ bone.tail = joints_matrix[i].cpu().numpy()
+ bone_dict[bone_name] = bone
+ for j, skinbone in enumerate(data['bones']):
+ if torch.equal(skinbone[:3], data['joints'][connectivity[i]]) and torch.equal(skinbone[3:], data['joints'][i]):
+ bone_name_list[j] = i_name
+ i_name += 1
+
+ for bone_name, bone in bone_dict.items():
+ # Find parent bone by checking if current bone's head matches any other bone's tail
+ for other_bone_name, other_bone in bone_dict.items():
+ if other_bone != bone and bone.head == other_bone.tail:
+ bone.parent = other_bone
+ break
+
+ print(bone_name_list)
+
+ reload_tensor_skinning(data, bone_name_list)
+
+ print("Armature modifier added to mesh:", mesh_object.name)
+
+ bpy.ops.object.mode_set(mode='OBJECT')
+ if save:
+ bpy.ops.wm.save_as_mainfile(filepath= os.path.join('/data2/ydengbd/Anymate/Anymate/data', data['name'], 'blender_output.blend'))
+
+ return armature_obj
+
+def load_blender(blender_path):
+
+ bpy.ops.wm.read_homefile(use_empty=True)
+ # bpy.ops.wm.append(directory=object_path, link=False)
+ # load_object(object_path)
+ bpy.ops.wm.open_mainfile(filepath=blender_path)
+ armature_obj = []
+ mesh_obj = []
+ for obj in bpy.context.scene.objects:
+ if obj.type == "ARMATURE":
+ armature_obj.append(obj)
+ if obj.type == "MESH":
+ mesh_obj.append(obj)
+
+ print('mesh obj:', len(mesh_obj))
+
+
+
+ # start retrieving the information of the mesh, skinning, and rigging
+
+ # 1. retrieve the rigging information and save the world matrix of the armature object
+ total_armature_info = {}
+ joints_matrix = []
+ bone_dict = {}
+ parent_name= []
+ bone_count = 0
+ for obj in armature_obj:
+ # depsgraph = bpy.context.evaluated_depsgraph_get()
+ # obj = obj.evaluated_get(depsgraph)
+ armature_info = {}
+ armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
+ translation = obj.matrix_world.translation
+ for bone in obj.pose.bones:
+
+ joints_matrix.append(np.array(list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())))
+
+ if bone.parent:
+ parent_name.append(bone.parent.name)
+ else:
+ parent_name.append('root')
+ bone_dict[bone.name] = bone_count
+ bone_count += 1
+ connectivity = torch.zeros(bone_count, dtype=torch.int32)
+
+ for i, bone_name in enumerate(parent_name):
+ if bone_name == 'root':
+ connectivity[i] = i
+ else:
+ connectivity[i] = bone_dict[bone_name]
+ joints_matrix = torch.from_numpy(np.array(joints_matrix))
+
+ skinning_weight = torch.zeros(len(mesh_obj[0].data.vertices), joints_matrix.shape[0])
+
+ vertex_index = 0
+ for obj in mesh_obj:
+ vertex_groups = obj.vertex_groups
+
+
+ for vertex in obj.data.vertices:
+ vertex_info = {}
+ for group in vertex.groups:
+ name = vertex_groups[group.group].name
+
+ weight = group.weight
+ skinning_weight[vertex.index][bone_dict[name]] = weight
+
+ obj_save_path = blender_path.replace('.blend', '.obj')
+ bpy.ops.wm.obj_export(filepath=obj_save_path, export_materials=False)
+ return joints_matrix,connectivity, skinning_weight
+
+
+def save_scene(scene_path):
+ # export the scene as a glb file
+ if scene_path.endswith('.glb'):
+ bpy.ops.export_scene.gltf(filepath=scene_path)
+ bpy.ops.wm.save_as_mainfile(filepath=scene_path.replace('.glb', '.blend'))
+ elif scene_path.endswith('.blend'):
+ bpy.ops.wm.save_as_mainfile(filepath=scene_path)
+ elif scene_path.endswith('.obj'):
+ bpy.ops.wm.obj_export(filepath=scene_path, export_materials=False)
+ else:
+ raise ValueError(f"Unsupported file extension: {scene_path}")
+
+if __name__ == '__main__':
+ # load the mesh
+ empty()
+ add_mesh('/home/ydengbd/objaverse/obj/0001.obj')
+ # load the joints
+ joints_matrix = np.load('/home/ydengbd/objaverse/joints/0001.npy')
+ add_joint(joints_matrix)
+ # load the connections
+ con_index = np.load('/home/ydengbd/objaverse/connections/0001.npy')
+ add_conn(con_index)
+ # load the skin
\ No newline at end of file
diff --git a/Anymate/utils/train_utils.py b/Anymate/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a96990bccd52ead08e6e329440ed850edfee1d
--- /dev/null
+++ b/Anymate/utils/train_utils.py
@@ -0,0 +1,406 @@
+import os
+import numpy as np
+from tqdm import tqdm
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader
+
+from Anymate.dataset import AnymateDataset, my_collate
+from Anymate.model import EncoderDecoder
+from Anymate.utils.loss_utils import cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp, chamfer_distance_with_average
+from Anymate.utils.vol_utils import get_co, get_gt, extract_keypoints
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+from torch.utils.data.distributed import DistributedSampler
+
+import point_cloud_utils as pcu
+from sklearn.cluster import DBSCAN
+from diffusers import DDPMScheduler, DDIMScheduler
+import torch.nn.functional as F
+from Anymate.utils.diffusion_utils import my_collate_diff, randn_tensor
+
+
+def ddp_setup(rank: int, world_size: int, port: int):
+ """
+ Args:
+ rank: Unique identifier of each process
+ world_size: Total number of processes
+ """
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+ torch.cuda.set_device(rank)
+ init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0.0
+ self.avg = 0.0
+ self.sum = 0.0
+ self.count = 0.0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def accumulate(self, val, n=1):
+ self.val = val
+ self.sum += val
+ self.count += n
+ self.avg = self.sum / self.count
+
+def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='model_best.pth.tar', snapshot=None):
+ filepath = os.path.join(checkpoint, filename)
+ if is_best:
+ torch.save(state, filepath)
+
+ if snapshot and state['epoch'] % snapshot == 0:
+ torch.save(state, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
+
+def train_model(rank, world_size, config, args, shared_dict, port=12355):
+ ddp_setup(rank, world_size, port)
+ lowest_loss = 1e20
+ model_config = config['model']
+ model = EncoderDecoder(device=f'cuda:{rank}', dtype=torch.float32, **model_config)
+ model.to(f'cuda:{rank}')
+
+ if rank == 0:
+ print('only_embed', model.only_embed)
+ print('return_latents', model.return_latents)
+ print(model)
+ if not args.finetune:
+ model.encoder.requires_grad_(False)
+ model = DDP(model, device_ids=[rank])
+ optimizer_config = config['optimizer']
+ if args.finetune:
+ optimizer = torch.optim.Adam(model.module.parameters(), **optimizer_config)
+ else:
+ if args.encoder == 'miche':
+ optimizer = torch.optim.Adam(model.module.decoder.parameters(), **optimizer_config)
+ elif args.encoder == 'bert':
+ optimizer = torch.optim.Adam(list(model.module.decoder.parameters()) + list(model.module.point_proj.parameters()), **optimizer_config)
+ # optionally resume from a checkpoint
+ if args.resume:
+ try:
+ print("=> loading checkpoint '{}'".format(args.resume))
+ checkpoint = torch.load(args.resume)
+ args.start_epoch = checkpoint['epoch']
+ lowest_loss = checkpoint['lowest_loss']
+ model.module.load_state_dict(checkpoint['state_dict'], strict=True)
+
+ print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
+ except:
+ print("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ print(' Total params: %.2fM' % (sum(p.numel() for p in optimizer.param_groups[0]['params']) / 1000000.0))
+ my_collate_func = my_collate_diff if args.mode == 'diffusion' else my_collate
+ if world_size > 1:
+ if not args.split:
+ train_dataset = shared_dict['train_dataset']
+ train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, sampler=train_sampler, collate_fn= my_collate_func)
+ else:
+ train_dataset = AnymateDataset(name=args.trainset + f'_{rank}', root=args.root)  # should be changed to a proper DDP-aware split
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
+ else:
+ train_dataset = AnymateDataset(name=args.trainset, root=args.root)
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
+
+ if rank == 0:
+ test_loader = DataLoader(AnymateDataset(name=args.testset, root=args.root), batch_size=args.test_batch, shuffle=False, collate_fn= my_collate_func )
+
+ if not args.schedule:
+ args.schedule = [args.epochs//2]
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
+ # step the scheduler to the start epoch
+ for _ in range(args.start_epoch):
+ scheduler.step()
+ if rank == 0:
+ logger = SummaryWriter(log_dir=args.logdir)
+ print('start training')
+ print('test_frequency', args.test_freq)
+ print('start from epoch', args.start_epoch)
+ # start training
+ for epoch in range(args.start_epoch, args.epochs):
+ test_dict = None
+ is_best = False
+ lr = scheduler.get_last_lr()
+ if rank == 0:
+ print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
+ train_loss, grad_norm = train(train_loader, model, optimizer, args)
+ if rank == 0 and (epoch == 0 or (epoch+1)%args.test_freq== 0):
+ print('Testing epoch', epoch+1)
+ test_dict = test(test_loader, model, args, world_size=world_size)
+
+
+ scheduler.step()
+ if rank == 0:
+ print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
+ print('Epoch{:d}. grad_norm: {:.6f}.'.format(epoch + 1, grad_norm))
+ info = {'train_loss': train_loss, 'grad_norm': grad_norm, 'lr': lr[0]}
+ # print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch + 1, val_loss))
+ if test_dict is not None:
+ for key, value in test_dict.items():
+ print('Epoch{:d}. {:s}: {:.6f}.'.format(epoch + 1, key, value))
+
+ test_loss = test_dict['test loss'] if not args.mode == 'diffusion' else test_dict['chamfer']
+ is_best = test_loss < lowest_loss
+ lowest_loss = min(test_loss, lowest_loss)
+ for key, value in test_dict.items():
+ info[key] = value
+
+ for tag, value in info.items():
+ logger.add_scalar(tag, value, epoch+1)
+ save_dict = {'epoch': epoch + 1, 'state_dict': model.module.state_dict(), 'lowest_loss': lowest_loss, 'optimizer': optimizer.state_dict(), 'model_config': model_config}
+ save_checkpoint(save_dict, is_best=is_best, checkpoint=args.checkpoint, snapshot=args.epochs//20)
+
+def get_criterion(args):
+ if args.loss == 'cos':
+ criterion = cos_loss
+ elif args.loss == 'ce':
+ criterion = cross_entropy_with_probs_batch
+ elif args.loss == 'cos_clamp':
+ criterion = cos_loss_clamp
+ else:
+ criterion = chamfer_distance_with_average
+ return criterion
+
+def get_train_loss(model, data, args):
+ criterion = get_criterion(args)
+ loss = 0.0
+ if args.mode == 'skin':
+ y_pred, idx = model(data, downsample=1024)
+ y_pred = torch.softmax(y_pred, dim=-1)
+ y = data['skins'].to(args.device)
+ y = y[:, idx]
+ loss = criterion(y_pred, y)
+
+ elif args.mode == 'conn':
+ y_pred = model(data, args.device)
+ y_pred = torch.softmax(y_pred, dim=-1)
+ y = data['conns'].to(args.device)
+ y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
+ loss = criterion(y_pred, y)
+
+ elif args.mode == 'joints': # joints mode
+ if args.decoder == 'transformer_latent':
+ y_pred = model(data, args.device)
+ joints_gt = data['joints'].to(args.device)
+ loss = 0.0
+ for i in range(joints_gt.shape[0]):
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
+ loss += criterion(y_pred[i:i+1], joints_gt_i.unsqueeze(0))
+ loss /= joints_gt.shape[0]
+
+ elif args.decoder == 'triplane' or args.decoder == 'implicit_transformer':
+ criterion = torch.nn.BCEWithLogitsLoss()
+ y_pred = model(data, args.device, downsample=True)
+ joints_gt = data['joints'].to(args.device)
+ for i in range(joints_gt.shape[0]):
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
+ vol = get_co(data['vox'][i])
+ if data['vox'][i].shape[0] > 50000:
+ vol = vol[y_pred[i][1]]
+ gt = get_gt(vol.to(args.device), joints_gt_i)
+ loss += criterion(y_pred[i][0].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
+ else:
+ gt = get_gt(vol.to(args.device), joints_gt_i)
+ loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
+ loss /= joints_gt.shape[0]
+
+ elif args.mode == 'diffusion':
+ noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
+
+ samples = data['joints_repeat'].to(model.device).float()
+ #use 256 input joints
+ samples = samples[...,:args.num_training_points,:]
+
+ samples = samples.to(model.device)
+ noise = torch.randn(samples.shape, device=samples.device)
+ assert samples.device == noise.device
+ bs = samples.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bs,), device=samples.device,
+ dtype=torch.int64
+ )
+
+ noisy_joints = noise_scheduler.add_noise(samples, noise, timesteps)
+ noisy_joints = noisy_joints.to(model.device)
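+ # The decoder works on channel-first joints (B, 3, N); permute here and permute the prediction back below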
+ noisy_joints = noisy_joints.permute(0, 2, 1)
+
+ noise_pred = model(data, noisy_joints=noisy_joints, timesteps = timesteps)
+ noise_pred = noise_pred.permute(0, 2, 1)
+ loss = F.mse_loss(noise_pred, noise)
+
+ return loss
+
+def train(train_loader, model, optimizer, args):
+ if not args.finetune:
+ model.train()
+ model.module.encoder.eval()
+ else:
+ model.train()
+ loss_meter = AverageMeter()
+ grad_norm_meter = AverageMeter()
+
+ for data in tqdm(train_loader):
+ loss = get_train_loss(model, data, args)
+ optimizer.zero_grad()
+ loss.backward()
+ grad_norm = 0
+
+ for p in optimizer.param_groups[0]['params']:
+ grad_norm += p.grad.data.norm(2).item()
+ grad_norm_meter.update(grad_norm)
+ optimizer.step()
+ loss_meter.update(loss.item())
+
+ return loss_meter.avg, grad_norm_meter.avg
+
+def test(test_loader, model, args, world_size=1):
+ model.eval()
+ assert args.mode in ['skin', 'joints', 'conn', 'diffusion'], 'mode should be chosen from [skin, joints, conn, diffusion], got {}'.format(args.mode)
+
+ if args.mode == 'skin' or args.mode == 'conn':
+ loss_meter = AverageMeter()
+ cos_sim_meter = AverageMeter()
+ cos_clamp_meter = AverageMeter()
+ for i, data in enumerate(tqdm(test_loader)):
+ if world_size > 1 and i > 1000:
+ break
+ with torch.no_grad():
+ y_pred = model(data, args.device)
+ y_pred = torch.softmax(y_pred, dim=-1)
+
+ if args.mode == 'skin':
+ y = data['skins'].to(args.device)
+ elif args.mode == 'conn':
+ y = data['conns'].to(args.device)
+ y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
+
+ loss = 0.0
+ loss = cross_entropy_with_probs_batch(y_pred, y)
+ loss_meter.update(loss.item())
+ cos_sim = cos_loss(y_pred, y)
+ cos_sim_meter.update(cos_sim.mean().item()) # 1 - loss.item()
+ cos_clamp = cos_loss_clamp(y_pred, y)
+ cos_clamp_meter.update(cos_clamp.mean().item())
+
+ loss_dict = {'test loss': loss_meter.avg, 'cos_sim': cos_sim_meter.avg, 'cos_clamp': cos_clamp_meter.avg}
+ # get the loss of the joints prediction
+ elif args.mode == 'joints':
+ if args.decoder == 'transformer_latent':
+ loss_meter = AverageMeter()
+ emd_meter = AverageMeter()
+ for i, data in tqdm(enumerate(test_loader)):
+ if world_size > 1 and i > 1000:
+ break
+ with torch.no_grad():
+ y_pred = model(data, args.device)
+ joints_gt = data['joints'].to(args.device)
+
+ loss = 0.0
+ emd = 0.0
+ for i in range(joints_gt.shape[0]):
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
+ y_pred_i = y_pred[i]
+
+ y_pred_i = y_pred[i].detach().cpu().numpy()
+ clustering = DBSCAN(eps=0.03, min_samples=1).fit(y_pred_i) # Consider add eps and min_samples as arguments
+ cluster_centers = []
+ for cluster in set(clustering.labels_):
+ cluster_centers.append(y_pred_i[clustering.labels_ == cluster].mean(axis=0))
+ y_pred_i = torch.from_numpy(np.array(cluster_centers)).to(args.device)
+
+ if y_pred_i.shape[0] < 2:
+ print(data['name'][i] + ' has less than 2 points')
+ continue
+ loss += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joints_gt_i.unsqueeze(0))
+ emd_i, pi = pcu.earth_movers_distance(y_pred_i.cpu().numpy().astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
+ emd += emd_i
+ if loss == 0 or emd == 0:
+ continue
+ loss /= joints_gt.shape[0]
+ loss_meter.update(loss.item())
+ emd_meter.update(emd)
+ loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg}
+
+ elif args.decoder == 'triplane' or args.decoder == 'implicit_transformer':
+ loss_meter = AverageMeter()
+ emd_meter = AverageMeter()
+ chamfer_meter = AverageMeter()
+ criterion = torch.nn.BCEWithLogitsLoss()
+ for data in tqdm(test_loader):
+ with torch.no_grad():
+ y_pred = model(data, args.device)
+ joints_gt = data['joints'].to(args.device)
+ loss = 0.0
+ emd = 0.0
+ chamfer = 0.0
+ for i in range(joints_gt.shape[0]):
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
+ vol = get_co(data['vox'][i])
+ gt = get_gt(vol.to(args.device), joints_gt_i)
+ loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
+ key_points = extract_keypoints(y_pred[i].cpu(), data['vox'][i])
+ if len(key_points) < 2:
+ continue
+ key_points = key_points / 32 - 1
+ chamfer += chamfer_distance_with_average(torch.from_numpy(key_points).unsqueeze(0).to(joints_gt_i.device), joints_gt_i.unsqueeze(0))
+ emd_i, _ = pcu.earth_movers_distance(key_points.astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
+ emd += emd_i
+ if loss == 0 or emd == 0 or chamfer == 0:
+ continue
+ loss /= joints_gt.shape[0]
+ loss_meter.update(loss.item())
+ emd_meter.update(emd)
+ chamfer_meter.update(chamfer.item())
+ loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg, 'chamfer': chamfer_meter.avg}
+
+ elif args.mode == 'diffusion':
+ loss_meter = AverageMeter()
+ emd_meter = AverageMeter()
+ chamfer_meter = AverageMeter()
+ generator=torch.Generator(device='cpu').manual_seed(args.seed+1)
+ scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
+ scheduler.set_timesteps(args.num_train_step)
+ points_shape = [args.test_batch, args.num_training_points, 3]
+ for data in tqdm(test_loader):
+ joints_gt = data['joints'].to(dtype=torch.float64)
+ points_noise = randn_tensor(points_shape, generator=generator)
+ points = points_noise.permute(0, 2, 1).to(model.device)
+ for t in scheduler.timesteps:
+ with torch.no_grad():
+ time_steps = torch.ones(args.test_batch, 1, dtype=torch.long) * t
+ time_steps = time_steps.to(model.device)
+ model_output = model(data, noisy_joints=points, timesteps = time_steps)
+
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
+ points = points.permute(0, 2, 1).cpu()
+
+ chamfer_sum = 0.0
+ emd_sum = 0.0
+
+ for i in range(args.test_batch):
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
+ points_i = points[i]
+ points_i = points_i.reshape( -1, 3)
+ emd, p = pcu.earth_movers_distance(points_i.cpu().numpy(),joints_gt_i[:,:3].cpu().numpy())
+ emd_sum += emd
+ chamfer_sum += chamfer_distance_with_average(points_i.unsqueeze(0),joints_gt_i[:,:3].unsqueeze(0))
+
+ emd_meter.update(emd_sum)
+ chamfer_meter.update(chamfer_sum.item())
+ loss_dict = {'chamfer': chamfer_meter.avg, 'emd': emd_meter.avg}
+
+ return loss_dict
diff --git a/Anymate/utils/ui_utils.py b/Anymate/utils/ui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..341ae52f6f828b950dddda587980b234e14db716
--- /dev/null
+++ b/Anymate/utils/ui_utils.py
@@ -0,0 +1,284 @@
+import trimesh
+import numpy as np
+import torch
+import os
+import matplotlib.pyplot as plt
+import gradio as gr
+import time
+bone_colors = plt.get_cmap('tab10')
+
+from Anymate.utils.utils import load_checkpoint, get_joint, get_connectivity, get_skinning
+from Anymate.utils.dataset_utils import obj2mesh
+from Anymate.args import anymate_args
+# from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joint, add_conn, add_skin, setup_armature
+
+def visualize_results(mesh_file=None, joints=None, conns=None, skins=None):
+ # Create a scene with both original and processed meshes
+ scene = trimesh.Scene()
+ vis_file = mesh_file.replace('object.obj', 'vis.glb')
+
+ if mesh_file is not None:
+ # Load the original mesh (in blue) with transparency
+ # original_mesh = trimesh.load(mesh_file)
+ original_mesh = obj2mesh(mesh_file)
+ if skins is not None:
+ # pdb.set_trace()
+ # Get per-vertex colors based on skinning weights
+ vertex_colors = np.zeros((len(original_mesh.vertices), 4))
+
+ # Convert skinning weights to numpy if needed
+ if isinstance(skins, torch.Tensor):
+ skins = skins.cpu().numpy()
+
+ # For each bone, blend colors based on skinning weights
+ for bone_idx in range(skins.shape[1]):
+ bone_color = np.array(bone_colors(bone_idx % 10)) # Get base color for this bone
+ weights = skins[:, bone_idx]
+ vertex_colors += np.outer(weights, bone_color) # Blend weighted colors
+
+ # Normalize and clip colors
+ vertex_colors = np.clip(vertex_colors, 0, 1)
+
+ # Convert to vertex colors and set alpha
+ vertex_colors = (vertex_colors * 255).astype(np.uint8)
+ vertex_colors[:, 3] = 255  # Set alpha to 255 (fully opaque)
+ # print(vertex_colors.shape)
+ # print(vertex_colors.max(axis=0), vertex_colors.min(axis=0), vertex_colors.mean(axis=0))
+
+ # Apply colors directly to vertices
+ original_mesh.visual.vertex_colors = vertex_colors
+
+ # face_colors = np.zeros((len(original_mesh.faces), 4))
+
+ # processed_mesh = trimesh.load(mesh_file)
+ processed_mesh = obj2mesh(mesh_file)
+ # Assign vertex colors from original_mesh to processed_mesh
+ # Since they might have different number of vertices, we need to find closest vertices
+
+ # Get vertices from both meshes
+ orig_vertices = original_mesh.vertices
+ proc_vertices = processed_mesh.vertices
+
+ # For each vertex in processed_mesh, find the closest vertex in original_mesh
+ closest_indices = []
+ for proc_vertex in proc_vertices:
+ # Calculate distances to all original vertices
+ distances = np.linalg.norm(orig_vertices - proc_vertex, axis=1)
+ # Find index of closest vertex
+ closest_idx = np.argmin(distances)
+ closest_indices.append(closest_idx)
+
+ proc_vertex_colors = original_mesh.visual.vertex_colors[closest_indices]
+ processed_mesh.visual.vertex_colors = proc_vertex_colors
+ original_mesh = processed_mesh
+
+ else:
+ original_mesh.visual.face_colors = [255, 255, 255, 100]  # White with alpha=100 for transparency
+ scene.add_geometry(original_mesh)
+
+ if joints is not None:
+ # create a sphere for each joint
+ for position in joints:
+ sphere = trimesh.primitives.Sphere(radius=0.02)
+ sphere.visual.face_colors = [255, 0, 0, 255]  # Opaque red
+ sphere.apply_translation(position.cpu().numpy())
+ scene.add_geometry(sphere)
+
+ if conns is not None:
+ # create a line for each connectivity
+ for i, conn in enumerate(conns):
+ if i == conn:
+ continue
+ # Create cylinder between joints
+ points = [joints[i].cpu().numpy(), joints[conn].cpu().numpy()]
+ direction = points[1] - points[0]
+ height = np.linalg.norm(direction)
+ cylinder = trimesh.primitives.Cylinder(radius=0.01, height=height)
+
+ # Calculate rotation matrix to align cylinder with direction
+ direction = direction / height # Normalize direction vector
+ up_vector = np.array([0, 0, 1])
+ rotation_matrix = trimesh.geometry.align_vectors(up_vector, direction)
+
+ # Apply rotation and translation to cylinder
+ cylinder.apply_transform(rotation_matrix)
+ cylinder.apply_translation(points[0] + direction * height/2)
+
+ cylinder.visual.face_colors = [0, 0, 255, 255] # Blue
+ scene.add_geometry(cylinder)
+
+ # Export the scene
+ scene.export(vis_file)
+ return vis_file
+
+
+def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
+ # Load the mesh from obj_path and sample a point cloud with per-point normals
+ try:
+ mesh = trimesh.load_mesh(obj_path)
+
+ points, face_idx = mesh.sample(sample_num, return_index=True)
+ normals = mesh.face_normals[face_idx]
+
+ pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+
+
+ if save_path is not None:
+ np.save(save_path, pc_normal)
+
+ return pc_normal
+ except Exception as e:
+ print(f"Error: {obj_path} {e}")
+ return None
+
+
+def normalize_mesh(mesh):
+ # Check if input is a scene with multiple meshes
+ if isinstance(mesh, trimesh.Scene):
+ # Combine all meshes in the scene into a single mesh
+ meshes = []
+ for node_name in mesh.graph.nodes_geometry:
+ # Each geometry node carries its own transform; look it up per node instead of reusing the first one
+ transform, geom_name = mesh.graph[node_name]
+ geometry = mesh.geometry[geom_name].copy()
+ if isinstance(geometry, trimesh.Trimesh):
+ # Transform mesh to scene (world) coordinates
+ geometry.apply_transform(transform)
+ meshes.append(geometry)
+
+ # Combine all meshes
+ mesh = trimesh.util.concatenate(meshes)
+
+ # Get vertices and compute bounding box
+ vertices = mesh.vertices
+ bbox_min = vertices.min(axis=0)
+ bbox_max = vertices.max(axis=0)
+
+ # Find center and scale
+ center = (bbox_min + bbox_max) * 0.5
+ scale = 2.0 / (bbox_max - bbox_min).max()
+
+ # Center and scale vertices
+ vertices = (vertices - center) * scale
+
+ # Create new mesh with normalized vertices
+ normalized_mesh = trimesh.Trimesh(vertices=vertices,
+ faces=mesh.faces,
+ face_normals=mesh.face_normals,
+ vertex_normals=mesh.vertex_normals,
+ process=False)
+
+ # # Copy texture from original mesh if it exists
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'material'):
+ # print("copy material")
+ # normalized_mesh.visual.material = mesh.visual.material
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'texture'):
+ # print("copy texture")
+ # normalized_mesh.visual.texture = mesh.visual.texture
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv'):
+ # print("copy uv")
+ # normalized_mesh.visual.uv = mesh.visual.uv
+
+ return normalized_mesh
+
+
+def vis_joint(normalized_mesh_file, joints):
+ if normalized_mesh_file is None or joints is None:
+ return None, None
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
+ return vis_file, vis_file
+
+def vis_connectivity(normalized_mesh_file, joints, conns):
+ if normalized_mesh_file is None or joints is None or conns is None:
+ return None, None
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns)
+ return vis_file, vis_file
+
+def vis_skinning(normalized_mesh_file, joints, conns, skins):
+ if normalized_mesh_file is None or joints is None or conns is None or skins is None:
+ return None, None
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
+ return vis_file, vis_file
+
+def prepare_blender_file(normalized_mesh_file):
+ if normalized_mesh_file is None:
+ return None
+
+ if not os.path.exists(normalized_mesh_file) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'joints.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'conns.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'skins.pt')):
+ return None
+
+ folder = normalized_mesh_file.replace('object.obj', '')
+ abs_folder = os.path.abspath(folder)
+ os.system(f"python Render.py --path {abs_folder}")
+
+ blender_file = os.path.join(folder, 'blender_output.blend')
+ while not os.path.exists(blender_file):
+ time.sleep(1)
+
+ return blender_file
+
+
+def process_input(mesh_file):
+ """
+ Function to handle input changes and initialize visualization
+
+ Args:
+ mesh_file: Path to the input mesh file
+
+ Returns:
+ The normalized mesh path, the visualization file (returned twice), a placeholder,
+ the sampled point cloud tensor, and placeholder values filled in by later steps
+ """
+
+ # For now just visualize the input mesh
+ if mesh_file is None:
+ return None, None, None, None, None, None, None, None
+
+ # make folder for tmp files
+ os.makedirs(f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}", exist_ok=True)
+
+ normalized_mesh = normalize_mesh(obj2mesh(mesh_file))
+ normalized_mesh_file = f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}/object.obj"
+ normalized_mesh.export(normalized_mesh_file)
+
+ # normalized_mesh.export(mesh_file)
+
+ vis_file = visualize_results(mesh_file=normalized_mesh_file)
+ pc = process_mesh_to_pc(normalized_mesh_file)
+ pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
+
+ # print(pc.shape, pc.max(dim=0), pc.min(dim=0))
+
+ return normalized_mesh_file, vis_file, vis_file, None, pc, None, None, None
+
+
+def get_model(checkpoint):
+ model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
+ return model, True
+
+def get_result_joint(mesh_file, model, pc, eps=0.03, min_samples=1):
+ return get_joint(pc, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'joints.pt'), eps=eps, min_samples=min_samples)
+
+def get_result_connectivity(mesh_file, model, pc, joints):
+ return get_connectivity(pc, joints, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'conns.pt'))
+
+def get_result_skinning(mesh_file, model, pc, joints, conns):
+ # mesh = trimesh.load(mesh_file)
+ mesh = obj2mesh(mesh_file)
+ vertices = torch.from_numpy(mesh.vertices).to(anymate_args.device).to(torch.float32)
+ vertex_normals = torch.from_numpy(mesh.vertex_normals).to(anymate_args.device).to(torch.float32)
+ vertices = torch.cat([vertices, vertex_normals], dim=-1)
+ return get_skinning(pc, joints, conns, model, vertices=vertices, device=anymate_args.device, save=mesh_file.replace('object.obj', 'skins.pt'))
+
+def get_all_models(checkpoint_joint, checkpoint_conn, checkpoint_skin):
+ model_joint = load_checkpoint(checkpoint_joint, anymate_args.device, anymate_args.num_joints)
+ model_connectivity = load_checkpoint(checkpoint_conn, anymate_args.device, anymate_args.num_joints)
+ model_skin = load_checkpoint(checkpoint_skin, anymate_args.device, anymate_args.num_joints)
+ return model_joint, model_connectivity, model_skin, True, True, True
+
+def get_all_results(mesh_file, model_joint, model_connectivity, model_skin, pc, eps=0.03, min_samples=1):
+ joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
+ conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
+ skins = get_result_skinning(mesh_file, model_skin, pc, joints, conns)
+ return joints, conns, skins
+
diff --git a/Anymate/utils/ui_utils_bpy.py b/Anymate/utils/ui_utils_bpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..943b7fc372571f348f3044ebba4b39772a34cf92
--- /dev/null
+++ b/Anymate/utils/ui_utils_bpy.py
@@ -0,0 +1,134 @@
+import trimesh
+import numpy as np
+import torch
+
+from Anymate.utils.utils import load_checkpoint, get_joints, get_connectivity
+from Anymate.args import anymate_args
+from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joints, add_conn, add_skin, setup_armature, save_scene
+
+def visualize_results(mesh_file=None, joints=None, connectivity=None, skinning=None):
+
+ import bpy
+ # Create a scene with both original and processed meshes
+ vis_file = "Anymate/tmp/vis_scene.glb"
+
+ # empty()
+ bpy.ops.wm.read_homefile(use_empty=True)
+
+ if mesh_file is not None:
+ # add_mesh(mesh_file)
+ bpy.ops.wm.obj_import(filepath=mesh_file)
+
+ if joints is not None:
+ add_joints(joints)
+
+ if connectivity is not None:
+ add_conn(connectivity, joints)
+
+ if skinning is not None:
+ add_skin(mesh_file, skinning)
+
+ # setup_armature()
+ # save_scene(vis_file)
+ bpy.ops.wm.save_as_mainfile(filepath=vis_file)
+ return vis_file
+
+
+def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
+ # Load the mesh from obj_path and sample a point cloud with per-point normals
+ try:
+ mesh = trimesh.load_mesh(obj_path)
+
+ points, face_idx = mesh.sample(sample_num, return_index=True)
+ normals = mesh.face_normals[face_idx]
+
+ pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+
+
+ if save_path is not None:
+ np.save(save_path, pc_normal)
+
+ return pc_normal
+ except Exception as e:
+ print(f"Error: {obj_path} {e}")
+ return None
+
+
+def normalize_mesh(mesh):
+ # Get vertices and compute bounding box
+ vertices = mesh.vertices
+ bbox_min = vertices.min(axis=0)
+ bbox_max = vertices.max(axis=0)
+
+ # Find center and scale
+ center = (bbox_min + bbox_max) * 0.5
+ scale = 2.0 / (bbox_max - bbox_min).max()
+
+ # Center and scale vertices
+ vertices = (vertices - center) * scale
+
+ # Create new mesh with normalized vertices
+ normalized_mesh = trimesh.Trimesh(vertices=vertices,
+ faces=mesh.faces,
+ face_normals=mesh.face_normals,
+ vertex_normals=mesh.vertex_normals)
+
+ return normalized_mesh
+
+
+def vis_joint(normalized_mesh_file, joints):
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
+ return vis_file
+
+def vis_connectivity(normalized_mesh_file, joints, connectivity):
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, connectivity=connectivity)
+ return vis_file
+
+def vis_skinning(skinning):
+ vis_file = visualize_results(skinning=skinning)
+ return vis_file
+
+
+def process_input(mesh_file):
+ """
+ Function to handle input changes and initialize visualization
+
+ Args:
+ mesh_file: Path to the input mesh file
+
+ Returns:
+ The normalized mesh path, the visualization file, the sampled point cloud tensor,
+ and placeholder values filled in by later steps
+ """
+
+ # For now just visualize the input mesh
+
+ normalized_mesh = normalize_mesh(trimesh.load(mesh_file))
+ normalized_mesh_file = "Anymate/tmp/normalized_mesh.obj"
+ normalized_mesh.export(normalized_mesh_file)
+ vis_file = visualize_results(mesh_file=normalized_mesh_file)
+ pc = process_mesh_to_pc(normalized_mesh_file)
+ pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
+
+ print(pc.shape, pc.max(dim=0), pc.min(dim=0))
+
+ return normalized_mesh_file, vis_file, pc, None, None, None
+
+
+def get_model(checkpoint):
+ model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
+ return model, True
+
+def get_result_joint(model, pc):
+ return get_joints(pc, model, anymate_args.device)
+
+def get_result_connectivity(model, pc, joints):
+ return get_connectivity(pc, joints, model, anymate_args.device)
+
+def get_result_skinning(model, pc):
+ with torch.no_grad():
+ skinning = model(pc)
+ return skinning
\ No newline at end of file
diff --git a/Anymate/utils/utils.py b/Anymate/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..483f899ab34fcbb8380292cfe7cf3cec9fa26aff
--- /dev/null
+++ b/Anymate/utils/utils.py
@@ -0,0 +1,77 @@
+import torch
+from Anymate.model import EncoderDecoder
+from sklearn.cluster import DBSCAN
+
+def load_checkpoint(path, device, num_joints):
+ print(f"Loading model from {path}")
+ model_state = torch.load(path)
+ model_weights = model_state['state_dict']
+
+ try:
+ model_config = model_state['model_config']
+ model = EncoderDecoder(device=device, dtype=torch.float32, **model_config)
+ model.to(device)
+ model.load_state_dict(model_weights, strict=True)
+ except:
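+ # Older checkpoints may not store 'model_config'; infer the encoder/decoder names from the checkpoint file name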
+ encoder = path.split('/')[-1].split('.')[0].split('-')[0]
+ decoder = path.split('/')[-1].split('.')[0].split('-')[1]
+ model = EncoderDecoder(encoder=encoder, decoder=decoder, device=device, dtype=torch.float32, num_joints=num_joints)
+ model.to(device)
+ model.load_state_dict(model_weights, strict=True)
+
+ print(f"Loaded model from {path}")
+
+ return model
+
+def get_joint(pc, model, device='cuda', save=None, vox=None, eps=0.03, min_samples=1):
+ model.eval()
+ data = {'points_cloud': pc.unsqueeze(0)}
+ if vox is not None:
+ data['vox'] = vox.unsqueeze(0)
+ with torch.no_grad():
+ model.decoder.inference_mode(eps=eps, min_samples=min_samples)
+ joints = model(data, device=device)
+ joints = torch.tensor(joints, dtype=torch.float32).to(device)
+
+ if save is not None:
+ torch.save(joints, save)
+
+ return joints
+
+def get_connectivity(pc, joints, model, device='cuda',return_prob=False, save=None):
+ model.eval()
+ data = {'points_cloud': pc.unsqueeze(0), 'joints': joints.unsqueeze(0), 'joints_num': torch.tensor([joints.shape[0]]),
+ 'joints_mask': torch.ones(joints.shape[0], device=device).unsqueeze(0)}
+ with torch.no_grad():
+ conns = model(data, device=device).softmax(dim=-1)
+ conns = conns.squeeze(0) if return_prob else torch.argmax(conns, dim=-1).squeeze(0)
+
+ if save is not None:
+ torch.save(conns, save)
+
+ return conns
+
+def get_skinning(pc, joints, conns, model, vertices=None, bones=None, device='cuda', save=None):
+ model.eval()
+
+ if bones is None:
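+ # Build bones from the connectivity: every non-root joint contributes one bone from its parent joint to itself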
+ bones = []
+ for i in range(joints.shape[0]):
+ if conns[i] != i:
+ bones.append(torch.cat((joints[conns[i]], joints[i]), dim=-1))
+ bones = torch.stack(bones, dim=0)
+
+ data = {'points_cloud': pc.unsqueeze(0), 'bones': bones.unsqueeze(0), 'bones_num': torch.tensor([bones.shape[0]]),
+ 'bones_mask': torch.ones(bones.shape[0], device=device).unsqueeze(0)}
+
+ if vertices is not None:
+ data['vertices'] = vertices.unsqueeze(0)
+ model.decoder.inference = True
+
+ with torch.no_grad():
+ skins = model(data, device=device).softmax(dim=-1).squeeze(0)
+
+ if save is not None:
+ torch.save(skins, save)
+
+ return skins
diff --git a/Anymate/utils/vol_utils.py b/Anymate/utils/vol_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f41ffc979dd381456df71b8c23ed763ea70e9ed1
--- /dev/null
+++ b/Anymate/utils/vol_utils.py
@@ -0,0 +1,135 @@
+import numpy as np
+import torch
+from ThirdParty.michelangelo.graphics.primitives import generate_dense_grid_points
+from sklearn.cluster import DBSCAN
+
+def get_vol(bounds=(-0.5, 0.0, -0.5, 0.5, 1.0, 0.5), octree_depth=6):
+
+ bbox_min = np.array(bounds[0:3])
+ bbox_max = np.array(bounds[3:6])
+ bbox_size = bbox_max - bbox_min
+
+ xyz_samples, grid_size, length = generate_dense_grid_points(
+ bbox_min=bbox_min,
+ bbox_max=bbox_max,
+ octree_depth=octree_depth,
+ indexing="ij"
+ )
+ xyz_samples = torch.FloatTensor(xyz_samples) # ((2^d)+1)^3
+
+ return xyz_samples
+
+def get_co(vox, bounds=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0), dtype = torch.float32):
+
+ bbox_min = torch.tensor(bounds[0:3]).to(vox.device)
+ bbox_max = torch.tensor(bounds[3:6]).to(vox.device)
+ bbox_size = bbox_max - bbox_min
+
+ # ind = torch.argwhere(vox)
+ # ind = ind.to(dtype) / (vox.shape[0]) * bbox_size + bbox_min
+ ind = vox
+ ind = ind.to(dtype) / 64 * bbox_size + bbox_min
+
+ return ind.to(dtype)
+
+def get_gt(vol, joints, octree_depth=6):
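+ # Soft target: a Gaussian of each sample's distance to its nearest joint, with sigma of roughly one voxel at this octree depth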
+ sigma = 2 / 2**octree_depth
+
+ dist = torch.cdist(vol, joints)
+ # print(dist.min(), dist.max())
+
+ dist = dist.min(dim=1).values
+
+ gt = torch.exp(-dist**2 / 2 / sigma**2)
+
+ return gt
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+ N, M, C = coordinates.shape
+ n_planes, _, _ = planes.shape
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+ projections = torch.bmm(coordinates, inv_planes)
+ return projections[..., :2]
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ # coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]]], dtype=torch.float32)
+
+def extract_keypoints(y_pred, vox):
+
+ y_pred = y_pred.detach().cpu().numpy()
+ vox = vox.detach().cpu().numpy()
+ volume = np.zeros([64, 64, 64])
+ volume[...] = -100
+ volume[vox[:, 0], vox[:, 1], vox[:, 2]] = y_pred.squeeze(-1)
+
+ clusters = []
+ cluster_model = DBSCAN(eps=1.8, min_samples=1)
+
+ level = min((0.85 * y_pred.max() + 0.15 * y_pred.min()).item(), 0)
+ potential_points = np.argwhere(volume >= level)
+ clustering = cluster_model.fit(potential_points)
+ for cluster in set(clustering.labels_):
+ if cluster == -1:
+ print('got noise', len(potential_points[clustering.labels_ == cluster]))
+ continue
+ clusters.append(potential_points[clustering.labels_ == cluster])
+
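+ # Iteratively refine: any cluster with 10+ voxels keeps only its higher-response voxels and is re-clustered;
+ # once all clusters are small, each cluster's mean voxel index becomes a keypoint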
+ while True:
+ if np.all(np.array([(len(cluster) < 10) for cluster in clusters])):
+ break
+ new_clusters = []
+ for points in clusters:
+ if len(points) < 10:
+ new_clusters.append(points)
+ continue
+
+ value = volume[points[:, 0], points[:, 1], points[:, 2]]
+
+ potential_points = points[value >= (0.1 * value.max() + 0.9 * value.min())]
+ clustering = cluster_model.fit(potential_points)
+ for cluster in set(clustering.labels_):
+ if cluster == -1:
+ print('got noise', len(potential_points[clustering.labels_ == cluster]))
+ continue
+ new_clusters.append(potential_points[clustering.labels_ == cluster])
+
+ clusters = new_clusters
+
+ key_points = np.array([cluster.mean(axis=0) for cluster in clusters])
+ key_points = key_points / 32 - 1
+
+ return key_points
\ No newline at end of file
diff --git a/Render.py b/Render.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d14b83fa9b374fe59a151603f290e7ddee3669
--- /dev/null
+++ b/Render.py
@@ -0,0 +1,17 @@
+import argparse
+import bpy
+import mathutils
+from Anymate.utils.render_utils import empty, setup_armature
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Anymate rendering script')
+ parser.add_argument('--path', type=str, required=True, help='Path to the model file')
+ return parser.parse_args()
+
+args = parse_args()
+
+print(f"Starting converting {args.path} to blender format...")
+
+empty()
+setup_armature(args.path)
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/.gitignore b/ThirdParty/PointLLM/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c79e9c0f36deb1fc479ee756f91fa833a09dd820
--- /dev/null
+++ b/ThirdParty/PointLLM/.gitignore
@@ -0,0 +1,12 @@
+__pycache__
+*.egg-info
+.vscode
+checkpoints
+outputs
+wandb
+anno_data
+objaverse_data
+modelnet40_data
+*.zip
+*.ipynb
+serving_workdirs
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/README.md b/ThirdParty/PointLLM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3315b991048449382e43939e3a7e14e313f0996e
--- /dev/null
+++ b/ThirdParty/PointLLM/README.md
@@ -0,0 +1,353 @@
+# PointLLM: Empowering Large Language Models to Understand Point Clouds
+
+Runsen Xu, Xiaolong Wang, Tai Wang, Yilun Chen, Jiangmiao Pang*, Dahua Lin
+
+The Chinese University of Hong Kong | Shanghai AI Laboratory | Zhejiang University
+
+## 🏠 About
+
+We introduce PointLLM, a multi-modal large language model capable of understanding colored point clouds of objects. It perceives object types, geometric structures, and appearance without concerns for ambiguous depth, occlusion, or viewpoint dependency. We collect a novel dataset comprising 660K simple and 70K complex point-text instruction pairs to enable a two-stage training strategy. To rigorously evaluate our model's perceptual abilities and its generalization capabilities, we establish two benchmarks: Generative 3D Object Classification and 3D Object Captioning, assessed through three different evaluation methods.
+
+## 🔥 News
+- [2024-09-06] We have uploaded the camera-ready version of PointLLM for ECCV 2024, which includes clearer writing and additional experimental results. Please check the paper [here](https://arxiv.org/abs/2308.16911).
+- [2024-07-01] PointLLM has been accepted by ECCV 2024 with all "strong accept" recommendations. 🎉 We are looking for self-motivated students to conduct research on PointLLM. Please send an email to runsxu@gmail.com with your CV if you are interested!
+- [2023-12-29] We release the codes of our online Gradio demo.
+- [2023-12-26] We release the codes for model evaluation, including ChatGPT/GPT-4 evaluation and traditional metric evaluation.
+- [2023-12-08] We release the codes for training and PointLLM-v1.2. The online demo has also been upgraded to the v1.2 version. Please enjoy! 🎉
+- [2023-12-01] We have released an updated version of our paper (v2), which includes additional baseline comparisons, enhanced human-evaluation metrics, improved model performance (PointLLM-v1.2), and other refinements. Please check the updated version [here](https://arxiv.org/abs/2308.16911).
+- [2023-10-18] We release our instruction-following data, including both the simple-description and complex instructions. Download [here](https://huggingface.co/datasets/RunsenXu/PointLLM).
+- [2023-09-26] We release the inferencing codes with checkpoints as well as the Objaverse colored point cloud files we use. You can chat with PointLLM with your own machines.
+- [2023-08-31] We release the [paper](http://arxiv.org/abs/2308.16911) of PointLLM and an online gradio [demo](http://101.230.144.196). Try it! 🎉
+
+
+## 📋 Contents
+- [🤖 Online Demo](#-online-demo)
+- [💬 Dialogue Examples](#-dialogue-examples)
+- [🔍 Overview](#-overview)
+- [📦 Training and Evaluation](#-training-and-evaluation)
+- [📝 TODO List](#-todo-list)
+- [🔗 Citation](#-citation)
+- [📄 License](#-license)
+- [📚 Related Work](#-related-work)
+- [👏 Acknowledgements](#-acknowledgements)
+
+## 🤖 Online Demo
+PointLLM is online! Try it at [http://101.230.144.196](http://101.230.144.196) or at [OpenXLab/PointLLM](https://openxlab.org.cn/apps/detail/openxlab-app/PointLLM).
+
+You can chat with PointLLM about the models of the [Objaverse](https://objaverse.allenai.org) dataset or about your own point clouds!
+
+Please do not hesitate to tell us if you have any feedback! 😃
+
+## 💬 Dialogue Examples
+| Dialogue 1 | Dialogue 2 | Dialogue 3 | Dialogue 4 |
+| :-: | :-: | :-: | :-: |
+| (example image) | (example image) | (example image) | (example image) |
+
+
+## 🔍 Overview
+
+### Model
+
+The point encoder extracts features from the input point cloud and projects them to the latent space of the LLM backbone. The LLM backbone processes sequences of point tokens and text tokens, and generates the predicted tokens as the output.
+
+### Experiment Results
+#### Quantitative Comparisons with baselines.
+Please refer to our paper for more results.
+
+!!!Note: Traditional metrics such as BLEU-1, ROUGE-L, and METEOR tend to favor shorter responses and may not effectively capture semantic accuracy. For a detailed discussion on this, please refer to our paper. We suggest the community not solely rely on these metrics for evaluation.
+
+#### Qualitative Comparisons with baselines.
+Please refer to our paper for more results.
+
+## 📦 Training and Evaluation
+### Installation
+We test our codes under the following environment:
+- Ubuntu 20.04
+- NVIDIA Driver: 515.65.01
+- CUDA 11.7
+- Python 3.10.13
+- PyTorch 2.0.1
+- Transformers 4.28.0.dev(transformers.git@cae78c46)
+
+To start:
+1. Clone this repository.
+```bash
+git clone git@github.com:OpenRobotLab/PointLLM.git
+cd PointLLM
+```
+2. Install packages
+```bash
+conda create -n pointllm python=3.10 -y
+conda activate pointllm
+pip install --upgrade pip # enable PEP 660 support
+pip install -e .
+
+# * for training
+pip install ninja
+pip install flash-attn
+```
+
+### Data Preparation
+#### Objaverse Training Data
+1. Download the two compressed files of 660K Objaverse colored point clouds [here](https://huggingface.co/datasets/RunsenXu/PointLLM/tree/main). They require about 77GB of storage space.
+2. Run the following command to merge the two files into one and uncompress it. This will produce a folder named `8192_npy` containing 660K point cloud files named `{Objaverse_ID}_8192.npy`. Each file is a numpy array with dimensions (8192, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` in the [0, 1] range. A snippet to sanity-check the extracted files is shown right after this list.
+```bash
+cat Objaverse_660K_8192_npy_split_a* > Objaverse_660K_8192_npy.tar.gz
+tar -xvf Objaverse_660K_8192_npy.tar.gz
+```
+3. In `PointLLM` folder, create a folder `data` and create a soft link to the uncompressed file in the directory.
+```bash
+cd PointLLM
+mkdir data
+ln -s /path/to/8192_npy data/objaverse_data
+```
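+
+To quickly sanity-check the extracted data (a minimal sketch, assuming the steps above are done and you run it from the `PointLLM` folder), load one file and inspect its shape and color range:
+```python
+import glob
+import numpy as np
+
+# Pick any point cloud file from the linked data folder
+path = sorted(glob.glob("data/objaverse_data/*_8192.npy"))[0]
+pc = np.load(path)                 # shape (8192, 6)
+xyz, rgb = pc[:, :3], pc[:, 3:]    # coordinates, and colors in [0, 1]
+print(path, pc.shape, rgb.min(), rgb.max())
+```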
+
+#### Instruction-Following Data
+1. In `PointLLM/data` folder, create a directory named `anno_data`.
+2. Our instruction-following data, including both the simple-description and complex instructions, can be downloaded [here](https://huggingface.co/datasets/RunsenXu/PointLLM). If you have difficulty downloading the data (e.g. network issue), please email the authors.
+- The simple-description data has 660K samples and the complex instructions have 70K samples.
+- Both training data are based on the Objaverse dataset.
+- The complex instructions are generated with GPT-4.
+3. Put the data files in the `anno_data` directory. The directory should look like this:
+```bash
+PointLLM/data/anno_data
+├── PointLLM_brief_description_660K_filtered.json
+├── PointLLM_brief_description_660K.json
+└── PointLLM_complex_instruction_70K.json
+```
+4. Note, the `PointLLM_brief_description_660K_filtered.json` is filtered from `PointLLM_brief_description_660K.json` by removing the 3000 objects we reserved as the validation set. If you want to reproduce the results in our paper, you should use the `PointLLM_brief_description_660K_filtered.json` for training. The `PointLLM_complex_instruction_70K.json` contains objects from the training set.
+5. If you want to generate the complex instructions by yourself, please refer to our paper for other details. The system prompt is at `pointllm/data/data_generation/system_prompt_gpt4_0613.txt`.
+
+#### Evaluation Data
+1. Download the referencing GT `PointLLM_brief_description_val_200_GT.json` we use for the benchmarks on Objaverse dataset [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_200_GT.json), and put it in `PointLLM/data/anno_data`. We also provide the 3000 object ids we filter during training [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/val_object_ids_3000.txt) and their corresponding referencing GT [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_3000_GT.json), which can be used to evaluate on all the 3000 objects.
+2. Create a directory named `modelnet40_data` in `PointLLM/data`. Download the test split of ModelNet40 point clouds `modelnet40_test_8192pts_fps.dat` [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/modelnet40_test_8192pts_fps.dat) and put it in `PointLLM/data/modelnet40_data`.
+
+### Training
+#### Download the Initial LLM and Point Encoder Weights
+1. In `PointLLM` folder, create a directory named `checkpoints`.
+2. Download the pre-trained LLM and point encoder: [PointLLM_7B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_7B_v1.1_init/tree/main) or [PointLLM_13B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_13B_v1.1_init/tree/main). Put them in the `checkpoints` directory.
+3. Note that the above "v1.1" means we use the Vicuna-v1.1 checkpoints, and you do **not** need to download the original LLaMA weights again.
+
+#### Start Training
+1. For stage-1 training, simply run:
+```bash
+cd PointLLM
+scripts/PointLLM_train_stage1.sh
+```
+2. After stage-1 training, start stage-2 training:
+```bash
+scripts/PointLLM_train_stage2.sh
+```
+
+#### PointLLM-v1.1 and PointLLM-v1.2
+Usually, you do not have to care about the following contents. They are only for reproducing the results in our v1 paper (PointLLM-v1.1). If you want to compare with our models or use our models for downstream tasks, please use PointLLM-v1.2 (refer to our v2 paper), which has better performance.
+
+The following steps are for reproducing PointLLM-v1.1:
+
+1. PointLLM v1.1 and v1.2 use slightly different pre-trained point encoders and projectors. If you want to reproduce PointLLM v1.1, edit the `config.json` file in the directory of initial LLM and point encoder weights, for example, `vim checkpoints/PointLLM_7B_v1.1_init/config.json`.
+
+2. Change the key `"point_backbone_config_name"` to specify another point encoder config:
+ ```bash
+ # change from
+ "point_backbone_config_name": "PointTransformer_8192point_2layer" # v1.2
+ # to
+ "point_backbone_config_name": "PointTransformer_base_8192point", # v1.1
+ ```
+
+3. Edit the checkpoint path of the point encoder in `scripts/train_stage1.sh`:
+ ```bash
+ # change from
+ point_backbone_ckpt=$model_name_or_path/point_bert_v1.2.pt # v1.2
+ # to
+ point_backbone_ckpt=$model_name_or_path/point_bert_v1.1.pt # v1.1
+ ```
+
+
+### Chatting
+1. The trained model checkpoints are available [here](https://huggingface.co/RunsenXu) (including different versions of PointLLM).
+2. Run the following command to launch a chatbot using the `torch.float32` data type for chatting about 3D models of Objaverse. The model checkpoints will be downloaded automatically. You can also manually download the model checkpoints and specify their paths. Here is an example:
+```bash
+cd PointLLM
+PYTHONPATH=$PWD python pointllm/eval/PointLLM_chat.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data --torch_dtype float32
+```
+3. You can also easily modify the codes to use point clouds other than those from Objaverse, as long as the point clouds input to the model have dimensions (N, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` (in the [0, 1] range). You may sample the point clouds to have 8192 points, as our model is trained on such point clouds; a short sketch of assembling such a point cloud is given after the table below.
+4. The following table shows GPU requirements for different models and data types. We recommend using `torch.bfloat16` if applicable, which is used in the experiments in our paper.
+
+ | Model | Data Type | GPU Memory |
+ |:--------:|:---------:|:----------:|
+ | PointLLM-7B | torch.float16 | 14GB |
+ | PointLLM-7B | torch.float32 | 28GB |
+ | PointLLM-13B | torch.float16 | 26GB |
+ | PointLLM-13B | torch.float32 | 52GB |
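+
+If you want to prepare your own point cloud in this layout (a minimal sketch; the arrays and the output file name below are placeholders, not part of the PointLLM codebase), assemble and save it like this:
+```python
+import numpy as np
+
+# Placeholder arrays: replace with your own coordinates and colors (colors in [0, 1])
+my_xyz = np.random.rand(20000, 3).astype(np.float32)
+my_rgb = np.random.rand(20000, 3).astype(np.float32)
+
+# Sample 8192 points to match the resolution the model was trained on
+idx = np.random.choice(len(my_xyz), 8192, replace=False)
+pc = np.concatenate([my_xyz[idx], my_rgb[idx]], axis=1)  # shape (8192, 6): xyz + rgb
+np.save("my_object_8192.npy", pc)
+```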
+
+### Gradio Demo
+1. We provide the codes for our online Gradio demo. You can run the following commands to launch the demo locally for chatting and visualization.
+```bash
+cd PointLLM
+PYTHONPATH=$PWD python pointllm/eval/chat_gradio.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data
+```
+2. Kind reminder: if you want to release the demo publicly, please refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access.
+
+### Evaluation
+#### Inferencing
+1. Run the following commands to infer the results.
+2. Different commands for inferencing on different benchmarks (PointLLM_7B_v1.2 as an example):
+```bash
+cd PointLLM
+export PYTHONPATH=$PWD
+
+# Open Vocabulary Classification on Objaverse
+python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 # or --prompt_index 1
+
+# Object captioning on Objaverse
+python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type captioning --prompt_index 2
+
+# Close-set Zero-shot Classification on ModelNet40
+python pointllm/eval/eval_modelnet_cls.py --model_name RunsenXu/PointLLM_7B_v1.2 --prompt_index 0 # or --prompt_index 1
+```
+3. Please check the default command-line arguments of these two scripts. You can specify different prompts, data paths, and other parameters.
+4. After inferencing, the results will be saved in `{model_name}/evaluation` as a dict with the following format:
+```bash
+{
+ "prompt": "",
+ "results": [
+ {
+ "object_id": "",
+ "ground_truth": "",
+ "model_output": "",
+ "label_name": "" # only for classification on modelnet40
+ }
+ ]
+}
+```
+
+#### ChatGPT/GPT-4 Evaluation
+1. Get your OpenAI API key at [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys).
+2. Run the following commands to evaluate the model outputs in parallel with ChatGPT/GPT-4 (which cost approximately $1.5 to $2.2 USD).
+```bash
+cd PointLLM
+export PYTHONPATH=$PWD
+export OPENAI_API_KEY=sk-****
+
+# Open Vocabulary Classification on Objaverse
+python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type open-free-form-classification --parallel --num_workers 15
+
+# Object captioning on Objaverse
+python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type object-captioning --parallel --num_workers 15
+
+# Close-set Zero-shot Classification on ModelNet40
+python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-3.5-turbo-0613 --eval_type modelnet-close-set-classification --parallel --num_workers 15
+```
+3. The evaluation script supports interruption and resumption. You can interrupt the evaluation process at any time by using `Ctrl+C`. This will save the temporary results. If an error occurs during the evaluation, the script will also save the current state. You can resume the evaluation from where it left off by running the same command again.
+4. The evaluation results will be saved in `{model_name}/evaluation` as another dict.
+Some of the metrics are explained as follows:
+```bash
+"average_score": The GPT-evaluated captioning score we report in our paper.
+"accuracy": The classification accuracy we report in our paper, including random choices made by ChatGPT when model outputs are vague or ambiguous and ChatGPT outputs "INVALID".
+"clean_accuracy": The classification accuracy after removing those "INVALID" outputs.
+"total_predictions": The number of predictions.
+"correct_predictions": The number of correct predictions.
+"invalid_responses": The number of "INVALID" outputs by ChatGPT.
+
+# Some other statistics for calling OpenAI API
+"prompt_tokens": The total number of tokens of the prompts for ChatGPT/GPT-4.
+"completion_tokens": The total number of tokens of the completion results from ChatGPT/GPT-4.
+"GPT_cost": The API cost of the whole evaluation process, in US Dollars 💵.
+```
+5. One-Step Evaluation. You can also start the evaluation immediately after inferencing by passing the `--start_eval` flag and specifying the `--gpt_type`. For example:
+```bash
+python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 --start_eval --gpt_type gpt-4-0613
+```
+
+#### Traditional Metric Evaluation
+1. For the object captioning task, run the following command to evaluate model outputs with traditional metrics including BLEU, ROUGE, METEOR, Sentence-BERT, and SimCSE.
+```bash
+python pointllm/eval/traditional_evaluator.py --results_path /path/to/model_captioning_output
+```
+2. Note that we recommend not using BLEU, ROUGE, and METEOR for evaluation as they favor short captions and fall short of capturing semantic accuracy and diversity. An illustrative sketch of the embedding-based similarity idea is shown below.
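+
+For intuition, Sentence-BERT and SimCSE score a caption by the similarity between sentence embeddings of the model output and the ground truth. The snippet below only illustrates that idea with the `sentence-transformers` package and an arbitrary checkpoint; the actual models and preprocessing used by `traditional_evaluator.py` may differ.
+```python
+from sentence_transformers import SentenceTransformer, util
+
+# Illustrative checkpoint; the repository's evaluator may use different embedding models.
+model = SentenceTransformer("all-MiniLM-L6-v2")
+
+caption = "A small wooden chair with four legs."
+ground_truth = "A wooden chair."
+
+emb = model.encode([caption, ground_truth], convert_to_tensor=True)
+score = util.cos_sim(emb[0], emb[1]).item()  # cosine similarity of the two sentence embeddings
+print(f"Embedding similarity: {score:.3f}")
+```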
+
+## 📝 TODO List
+- [x] Add inferencing codes with checkpoints.
+- [x] Release instruction-following data.
+- [x] Add training codes.
+- [x] Add evaluation codes.
+- [x] Add gradio demo codes.
+
+Community contributions are welcome!👇 If you need any support, please feel free to open an issue or contact us.
+- [ ] Support Phi-2 LLM to make PointLLM more accessible to the community.
+- [ ] Support Chinese LLMs like InternLM.
+
+## 🔗 Citation
+
+If you find our work and this codebase helpful, please consider starring this repo 🌟 and citing:
+
+```bibtex
+@article{xu2023pointllm,
+ title={PointLLM: Empowering Large Language Models to Understand Point Clouds},
+ author={Xu, Runsen and Wang, Xiaolong and Wang, Tai and Chen, Yilun and Pang, Jiangmiao and Lin, Dahua},
+ journal={arXiv preprint arXiv:2308.16911},
+ year={2023}
+}
+```
+
+## 📄 License
+
+This work is under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
+
+## 📚 Related Work
+Together, let's make LLMs for 3D great!
+- [Point-Bind & Point-LLM](https://arxiv.org/abs/2309.00615): aligns point clouds with Image-Bind, and leverages ImageBind-LLM to reason multi-modality input without 3D-instruction data training.
+- [3D-LLM](https://arxiv.org/abs/2307.12981): employs 2D foundation models to encode multi-view images of 3D point clouds.
+
+
+## 👏 Acknowledgements
+- [LLaVA](https://github.com/haotian-liu/LLaVA): Our codebase is built upon LLaVA.
+- [Vicuna](https://github.com/lm-sys/FastChat): We use the Vicuna-7B and Vicuna-13B checkpoints.
+- [Objaverse](https://objaverse.allenai.org): We use models of the Objaverse dataset for training and evaluation.
+- [Cap3D](https://github.com/crockwell/Cap3D/): We use the Cap3D captioning data for our data generation.
+- [ULIP-2](https://github.com/salesforce/ULIP): We use ULIP-2 for pre-training our point cloud encoder.
diff --git a/ThirdParty/PointLLM/__init__.py b/ThirdParty/PointLLM/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThirdParty/PointLLM/pointllm/__init__.py b/ThirdParty/PointLLM/pointllm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43701abfd68f05cd3bf1a85117b96c4ecc58299
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/__init__.py
@@ -0,0 +1 @@
+# from .model import PointLLMLlamaForCausalLM
diff --git a/ThirdParty/PointLLM/pointllm/conversation.py b/ThirdParty/PointLLM/pointllm/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5350627507c3ef2f6f36f4a99ca3671f2995d1c8
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/conversation.py
@@ -0,0 +1,375 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def reset(self):
+ self.messages = self.messages[:self.offset]
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ if self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def pop_last_none_message(self):
+ # * pop the last message if it's None, this is used for multi-round dialogue
+ if self.messages[-1][1] is None:
+ self.messages.pop()
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode == "Crop":
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((224, 224))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ # image = image.resize((224, 224))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+ msg = msg.replace('<image>', img_str)
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Give three tips for staying healthy."),
+ ("Assistant",
+ "Sure, here are three tips for staying healthy:\n"
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+ "activities at least two days per week.\n"
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
+ "and aim to drink plenty of water throughout the day.\n"
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+ "help improve the quality of your sleep.")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_v1_2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="</s>",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+ system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+ system="BEGINNING OF CONVERSATION:",
+ roles=("USER", "GPT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="</s>",
+)
+
+simple_conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+ system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+ "You are designed to assist human with a variety of tasks using natural language."
+ "Follow the instructions carefully.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!\n\n### Response:"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v1 = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="</s>",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+ "default": conv_v1_2,
+ "simple": simple_conv,
+ "simple_legacy": simple_conv_legacy,
+ "multimodal": simple_conv_multimodal,
+ "mpt_multimodal": simple_conv_mpt_multimodal,
+ "llava_v1": conv_llava_v1,
+
+ # fastchat
+ "v1": conv_v1_2,
+ "bair_v1": conv_bair_v1,
+ "vicuna_v1_1": conv_vicuna_v1_1,
+ "mpt": conv_mpt,
+ "mpt_text": conv_mpt_text,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/ThirdParty/PointLLM/pointllm/data/__init__.py b/ThirdParty/PointLLM/pointllm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2681ef21d7b4c758651eda7320bec4b5cbfc5b20
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/data/__init__.py
@@ -0,0 +1,3 @@
+from .utils import load_objaverse_point_cloud, pc_norm, farthest_point_sample
+from .object_point_dataset import ObjectPointCloudDataset, make_object_point_data_module
+from .modelnet import ModelNet
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/data/modelnet.py b/ThirdParty/PointLLM/pointllm/data/modelnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae42d24ec0a41c53bde71054176180e0e5c4bbce
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/data/modelnet.py
@@ -0,0 +1,147 @@
+import os
+import torch
+import numpy as np
+import pickle
+from torch.utils.data import Dataset
+from pointllm.utils import *
+from pointllm.data.utils import *
+
+class ModelNet(Dataset):
+ def __init__(self, config_path, split, subset_nums=-1, use_color=False):
+ """
+ Args:
+ config_path: path to the dataset yaml config; if None, modelnet_config/ModelNet40.yaml is used.
+ split: 'train' or 'test'.
+ subset_nums: if > 0, randomly subsample this many shapes.
+ use_color: whether to append (zero) color channels to the points.
+ """
+ super(ModelNet, self).__init__()
+
+ if config_path is None:
+ # * use the default config file in the same dir
+ config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")
+
+ config = cfg_from_yaml_file(config_path)
+ # * check data path
+ self.root = config["DATA_PATH"]
+
+ if not os.path.exists(self.root):
+ print(f"Data path {self.root} does not exist. Please check your data path.")
+ exit()
+
+ self.npoints = config.npoints
+ self.num_category = config.NUM_CATEGORY # * should be 40
+ self.random_sample = config.random_sampling
+ self.use_height = config.use_height
+ self.use_normals = config.USE_NORMALS
+ self.subset_nums = subset_nums
+ self.normalize_pc = True
+ self.use_color = use_color
+
+ if self.use_height or self.use_normals:
+ print(f"Warning: Usually we don't use height or normals for shapenet but use_height: {self.use_height} and \
+ use_normals: {self.use_normals}.")
+
+ self.split = split
+ assert (self.split == 'train' or self.split == 'test')
+
+ self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')
+
+ # "tv_stand" -> "tv stand"
+ self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names
+
+ self.save_path = os.path.join(self.root,
+ 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))
+
+ print('Load processed data from %s...' % self.save_path)
+ with open(self.save_path, 'rb') as f:
+ self.list_of_points, self.list_of_labels = pickle.load(f) # * ndarray of N, C: (8192, 6) (xyz and normals)
+
+ if self.subset_nums > 0:
+ # * set random seed
+ import random
+ random.seed(0)
+ # * random choose subset_nums
+ idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
+ self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
+ self.list_of_points = [self.list_of_points[idx] for idx in idxs]
+
+ # * print len
+ print(f"Load {len(self.list_of_points)} data from {self.save_path}.")
+
+ def __len__(self):
+ return len(self.list_of_labels)
+
+ def _get_item(self, index):
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
+
+ if self.npoints < point_set.shape[0]:
+ if self.random_sample:
+ # * random sample
+ point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
+ else:
+ point_set = farthest_point_sample(point_set, self.npoints)
+
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+ if not self.use_normals:
+ point_set = point_set[:, 0:3]
+
+ if self.use_height:
+ self.gravity_dim = 1
+ height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
+ self.gravity_dim:self.gravity_dim + 1].min()
+ point_set = np.concatenate((point_set, height_array), axis=1)
+
+ point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
+
+ return point_set, label.item() # * ndarray, int
+
+ def pc_norm(self, pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+ def __getitem__(self, index):
+ points, label = self._get_item(index)
+ pt_idxs = np.arange(0, points.shape[0]) # 2048
+ if self.split == 'train':
+ np.random.shuffle(pt_idxs)
+ current_points = points[pt_idxs].copy()
+
+ if self.normalize_pc:
+ # * modelnet point cloud is already normalized
+ current_points = self.pc_norm(current_points)
+
+ current_points = torch.from_numpy(current_points).float() # * N, C tensors
+ label_name = self.categories[int(label)]
+
+ data_dict = {
+ "indice": index, # * int
+ "point_clouds": current_points, # * tensor of N, C
+ "labels": label, # * int
+ "label_names": label_name # * str
+ }
+
+ return data_dict
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='ModelNet Dataset')
+
+ parser.add_argument("--config_path", type=str, default=None, help="config file path.")
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
+ parser.add_argument("--subset_nums", type=int, default=200)
+
+ args = parser.parse_args()
+
+ dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)
+
+ # * get the first item
+ print(dataset[0])
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml b/ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1519c08a16dd78c8bb17cef58e138048534c37f7
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml
@@ -0,0 +1,8 @@
+NAME: ModelNet
+DATA_PATH: data/modelnet40_data
+NUM_CATEGORY: 40
+USE_NORMALS: FALSE
+npoints: 8192
+random_sampling: TRUE
+use_height: FALSE
+use_normals: FALSE
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/data/object_point_dataset.py b/ThirdParty/PointLLM/pointllm/data/object_point_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ab0f30ece7ff860df70abce0151d918b82d1e6a
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/data/object_point_dataset.py
@@ -0,0 +1,250 @@
+import os
+import json
+import torch
+import numpy as np
+
+import copy
+import transformers
+from torch.utils.data import Dataset
+
+from .utils import *
+
+
+def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+ """Make dataset and collator for Joint3Ddataset with text and point cloud data."""
+ """Initialize datasets."""
+
+ data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
+ if data_args.split_train_val:
+ print("Loading training datasets.")
+ train_dataset = ObjectPointCloudDataset(
+ split='train',
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ tokenizer=tokenizer,
+ use_color=data_args.use_color,
+ data_args=data_args
+ )
+ print("Done!")
+ if data_args.data_debug_num > 0:
+ print('Debug mode, using training set as val set.')
+ val_dataset = train_dataset
+ else:
+ # * make a val dataset
+ print("Loading validation datasets.")
+ val_dataset = ObjectPointCloudDataset(
+ split='val', # * load val split
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ tokenizer=tokenizer,
+ use_color=data_args.use_color,
+ data_args=data_args
+ )
+ return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
+ else:
+ # * use all data as training data
+ train_dataset = ObjectPointCloudDataset(
+ split='train',
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ use_color=data_args.use_color,
+ tokenizer=tokenizer,
+ data_args=data_args
+ )
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+class ObjectPointCloudDataset(Dataset):
+ """Dataset utilities for objaverse."""
+ def __init__(self,
+ data_path=None,
+ anno_path=None,
+ tokenizer=None,
+ pointnum=8192,
+ split='train',
+ conversation_types=None, # * default is simple_des, used for stage1 pre-train
+ use_color=True,
+ data_args=None):
+
+ """
+ split: only considered when data_args.split_train_val is True.
+ conversation_types: tuple used to filter the data; default is ('simple_description',), other types are:
+ "detailed_description", "single_round", "multi_round".
+ tokenizer: if None, only point clouds are loaded (no text tokenization).
+ """
+ super(ObjectPointCloudDataset, self).__init__()
+
+ """Initialize dataset with object point clouds and text"""
+ self.data_path = data_path
+ self.anno_path = anno_path
+ self.tokenizer = tokenizer
+ self.split = split
+ if conversation_types is None:
+ self.conversation_types = ("simple_description",)
+ else:
+ self.conversation_types = conversation_types
+
+ self.data_args = data_args
+ self.normalize_pc = True
+ self.use_color = use_color
+
+ self.pointnum = pointnum
+ self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
+ self.point_indicator = '<point>'
+
+ # Load the data list from JSON
+ print(f"Loading anno file from {anno_path}.")
+ with open(anno_path, "r") as json_file:
+ self.list_data_dict = json.load(json_file)
+
+ # * print the conversations_type
+ print(f"Using conversation_type: {self.conversation_types}")
+ # * print before filtering
+ print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")
+
+ # * iterate the list and filter
+ # * these two ids have corrupted colored point files, so filter them when use_color is True
+ filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []
+
+ # Iterate the list, filter those "conversation_type" not in self.conversation_types
+ self.list_data_dict = [
+ data for data in self.list_data_dict
+ if data.get('conversation_type', 'simple_description') in self.conversation_types
+ and data.get('object_id') not in filter_ids
+ ]
+
+ # * print after filtering
+ print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")
+ # * print the size of different conversation_type
+ for conversation_type in self.conversation_types:
+ print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")
+
+ if self.data_args is not None and self.data_args.data_debug_num > 0:
+ self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
+ # * print all the scan_id in debug mode, not using for loop
+ print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
+ elif self.data_args is not None and self.data_args.split_train_val:
+ # * split train and val with 9:1 ratios
+ if self.split == 'train':
+ self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
+ print(f"Train set size: {len(self.list_data_dict)}")
+ else:
+ self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
+ print(f"Val set size: {len(self.list_data_dict)}")
+
+ def _load_point_cloud(self, object_id, type='objaverse'):
+ if type == 'objaverse':
+ return self._load_objaverse_point_cloud(object_id)
+
+ def _load_objaverse_point_cloud(self, object_id):
+ filename = f"{object_id}_{self.pointnum}.npy"
+ point_cloud = np.load(os.path.join(self.data_path, filename))
+
+ if not self.use_color:
+ point_cloud = point_cloud[:, :3]
+
+ return point_cloud
+
+ def pc_norm(self, pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+ def __getitem__(self, index):
+ sources = self.list_data_dict[index]
+ if isinstance(index, int):
+ sources = [sources]
+ assert len(sources) == 1, "sources should be a list"
+ if self.point_indicator in sources[0]['conversations'][0]['value']:
+
+ object_id = self.list_data_dict[index]['object_id']
+
+ # Point cloud representation
+ point_cloud = self._load_point_cloud(object_id) # * N, C
+ if self.normalize_pc:
+ point_cloud = self.pc_norm(point_cloud) # * need to norm since point encoder is norm
+
+ if self.tokenizer is None:
+ data_dict = dict(
+ point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
+ object_ids=object_id
+ )
+ return data_dict
+
+ sources = preprocess_multimodal_point_cloud(
+ copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+
+ data_dict = preprocess_v1(
+ sources,
+ self.tokenizer)
+
+ if isinstance(index, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # point exist in the data
+ if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
+ data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))
+
+ return data_dict
+
+ def __len__(self):
+ """Return number of utterances."""
+ return len(self.list_data_dict)
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--data_path", default="data/objaverse_data", type=str,
+ help="Path to the data directory.")
+ parser.add_argument("--anno_path", default=None, type=str, required=True,
+ help="Path to the annotation file.")
+ parser.add_argument("--split", default='train', type=str,
+ help="Whether to use the train or validation dataset.")
+ parser.add_argument("--pointnum", default=8192, type=int,
+ help="Number of points in the point cloud.")
+ parser.add_argument("--data_debug_num", default=0, type=int,
+ help="Number of data to debug with.")
+ parser.add_argument("--split_train_val", default=False, type=bool,
+ help="Whether to split the dataset into training and validation.")
+ parser.add_argument("--split_ratio", default=0.9, type=float,
+ help="The ratio of training to validation data.")
+ parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
+ help="Path to the tokenizer config file.")
+
+ args = parser.parse_args()
+
+ # Initialize tokenizer
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)
+
+ args.point_backbone_config = None
+
+ # Initialize dataset
+ dataset = ObjectPointCloudDataset(
+ data_path=args.data_path,
+ anno_path=args.anno_path,
+ pointnum=args.pointnum,
+ split=args.split,
+ tokenizer=tokenizer,
+ data_args=args
+ )
+
+ # Example usage
+ print(f'Dataset length: {len(dataset)}')
+
diff --git a/ThirdParty/PointLLM/pointllm/data/utils.py b/ThirdParty/PointLLM/pointllm/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c41aaca765e4e670207ee798807ec64c65730a48
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/data/utils.py
@@ -0,0 +1,236 @@
+from collections import OrderedDict, defaultdict
+
+import transformers
+from pointllm import conversation as conversation_lib
+from dataclasses import dataclass
+from typing import Optional, Dict, Sequence
+import torch
+
+import numpy as np
+import os
+
+IGNORE_INDEX = -100
+
+# * Sample Usage:
+# * from utils import LRUCache
+# * cache = LRUCache(capacity, max_access_count)
+# if self.cache is None:
+# info_data = self.multiview_scannet[info_index]
+# else:
+# info_data = self.cache.get(info_index)
+# if info_data is None or self.cache.get_access_count(info_index) >= self.cache.max_access_count:
+# # If not in cache, or accessed max_access_count times, load it and put it in cache
+# info_data = self.multiview_scannet[info_index]
+# self.cache.put(info_index, info_data)
+# self.cache.reset_access_count(info_index)
+
+class LRUCache:
+ def __init__(self, capacity, max_access_count):
+ self.cache = OrderedDict()
+ self.access_count = defaultdict(int)
+ self.capacity = capacity
+ self.max_access_count = max_access_count
+
+ def get(self, key):
+ if key not in self.cache:
+ return None
+ value = self.cache.pop(key)
+ self.cache[key] = value # Put key as the newest one
+ self.access_count[key] += 1
+ return value
+
+ def put(self, key, value):
+ if key in self.cache: # Update the value and put it as newest
+ self.cache.pop(key)
+ elif len(self.cache) == self.capacity: # If cache is full
+ oldest_key = next(iter(self.cache))
+ self.cache.popitem(last=False) # Remove oldest item
+ del self.access_count[oldest_key] # Remove the corresponding access count
+ self.cache[key] = value
+ self.access_count[key] = 1
+
+ def get_access_count(self, key):
+ return self.access_count.get(key, 0)
+
+ def reset_access_count(self, key):
+ self.access_count[key] = 0
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2: # * can handle padded tokens
+ break
+ parts[0] += sep
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX # * this is necessary for padded tokens
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len: # * unk tokens in the dialogue will cause this.
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+def preprocess_multimodal_point_cloud(
+ sources: Sequence[str],
+ point_backbone_config: dict,
+ point_indicator: str = "<point>",
+) -> Dict:
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+
+ for source in sources:
+ for sentence in source:
+ replace_token = default_point_patch_token * point_token_len
+ if point_backbone_config['mm_use_point_start_end']:
+ replace_token = point_backbone_config['default_point_start_token']+ replace_token + point_backbone_config['default_point_end_token']
+ sentence["value"] = sentence["value"].replace(point_indicator, replace_token)
+
+ return sources
+
+def pc_norm(pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
+ filename = f"{object_id}_{pointnum}.npy"
+ point_cloud = np.load(os.path.join(data_path, filename))
+
+ # * normalize
+ point_cloud = pc_norm(point_cloud)
+
+ if not use_color:
+ point_cloud = point_cloud[:, :3]
+
+ return point_cloud
+
+@dataclass
+class DataCollatorForPointTextDataset(object):
+ """Collate examples for mixed dataset with text and point cloud data."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'point_clouds' in instances[0]:
+ point_clouds = [instance['point_clouds'] for instance in instances]
+ if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds): # * stack only if all point clouds share the same shape
+ batch['point_clouds'] = torch.stack(point_clouds)
+ else:
+ batch['point_clouds'] = point_clouds # * return as lists
+
+ return batch
+
+def farthest_point_sample(point, npoint):
+ """
+ Input:
+ point: point cloud data, [N, D]
+ npoint: number of samples
+ Return:
+ point: farthest-point-sampled point cloud, [npoint, D]
+ """
+ N, D = point.shape
+ xyz = point[:,:3]
+ centroids = np.zeros((npoint,))
+ distance = np.ones((N,)) * 1e10
+ farthest = np.random.randint(0, N)
+ for i in range(npoint):
+ centroids[i] = farthest
+ centroid = xyz[farthest, :]
+ dist = np.sum((xyz - centroid) ** 2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = np.argmax(distance, -1)
+ point = point[centroids.astype(np.int32)]
+ return point
+
+def pc_normalize(pc):
+ """
+ pc: Nx3 array
+ This functions normalizes a point cloud to fit within a unit sphere.
+ It first calculates the centroid of the point cloud and then subtracts
+ it from all points before scaling all points to fit within a unit sphere.
+ """
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+ pc = pc / m
+ return pc
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py b/ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..920a271c1fd3c784d69055e681ca7091951d78a8
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py
@@ -0,0 +1,157 @@
+import argparse
+from transformers import AutoTokenizer
+import torch
+import os
+from pointllm.conversation import conv_templates, SeparatorStyle
+from pointllm.utils import disable_torch_init
+from pointllm.model import *
+from pointllm.model.utils import KeywordsStoppingCriteria
+
+from pointllm.data import load_objaverse_point_cloud
+
+import os
+
+def load_point_cloud(args):
+ object_id = args.object_id
+ print(f"[INFO] Loading point clouds using object_id: {object_id}")
+ point_cloud = load_objaverse_point_cloud(args.data_path, object_id, pointnum=8192, use_color=True)
+
+ return object_id, torch.from_numpy(point_cloud).unsqueeze_(0).to(torch.float32)
+
+def init_model(args):
+ # Model
+ disable_torch_init()
+
+ model_path = args.model_name
+ print(f'[INFO] Model name: {model_path}')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=False, use_cache=True, torch_dtype=args.torch_dtype).cuda()
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
+
+ model.eval()
+
+ mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
+ # Add special tokens ind to model.point_config
+ point_backbone_config = model.get_model().point_backbone_config
+
+ if mm_use_point_start_end:
+ if "v1" in model_path.lower():
+ conv_mode = "vicuna_v1_1"
+ else:
+ raise NotImplementedError
+
+ conv = conv_templates[conv_mode].copy()
+
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+
+ return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
+
+def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+ default_point_start_token = point_backbone_config['default_point_start_token']
+ default_point_end_token = point_backbone_config['default_point_end_token']
+ # The while loop will keep running until the user decides to quit
+ print("[INFO] Starting conversation... Enter 'q' to exit the program and enter 'exit' to exit the current conversation.")
+ while True:
+ print("-" * 80)
+ # Prompt for object_id
+ object_id = input("[INFO] Please enter the object_id or 'q' to quit: ")
+
+ # Check if the user wants to quit
+ if object_id.lower() == 'q':
+ print("[INFO] Quitting...")
+ break
+ else:
+ # print info
+ print(f"[INFO] Chatting with object_id: {object_id}.")
+
+ # Update args with new object_id
+ args.object_id = object_id.strip()
+
+ # Load the point cloud data
+ try:
+ id, point_clouds = load_point_cloud(args)
+ except Exception as e:
+ print(f"[ERROR] {e}")
+ continue
+ point_clouds = point_clouds.cuda().to(args.torch_dtype)
+
+ # Reset the conversation template
+ conv.reset()
+
+ print("-" * 80)
+
+ # Start a loop for multiple rounds of dialogue
+ for i in range(100):
+ # This if-else block ensures the initial question from the user is included in the conversation
+ qs = input(conv.roles[0] + ': ')
+ if qs == 'exit':
+ break
+
+ if i == 0:
+ if mm_use_point_start_end:
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
+ else:
+ qs = default_point_patch_token * point_token_len + '\n' + qs
+
+ # Append the new message to the conversation history
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ inputs = tokenizer([prompt])
+
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
+
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ stop_str = keywords[0]
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ point_clouds=point_clouds,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ max_length=2048,
+ top_p=0.95,
+ stopping_criteria=[stopping_criteria])
+
+ input_token_len = input_ids.shape[1]
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+ if n_diff_input_output > 0:
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+ outputs = outputs.strip()
+ if outputs.endswith(stop_str):
+ outputs = outputs[:-len(stop_str)]
+ outputs = outputs.strip()
+
+ # Append the model's response to the conversation history
+ conv.pop_last_none_message()
+ conv.append_message(conv.roles[1], outputs)
+ print(f'{conv.roles[1]}: {outputs}\n')
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, \
+ default="RunsenXu/PointLLM_7B_v1.2")
+
+ parser.add_argument("--data_path", type=str, default="data/objaverse_data")
+ parser.add_argument("--torch_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])
+
+ args = parser.parse_args()
+
+ dtype_mapping = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ }
+
+ args.torch_dtype = dtype_mapping[args.torch_dtype]
+
+ model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
+
+ start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/eval/chat_gradio.py b/ThirdParty/PointLLM/pointllm/eval/chat_gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..10ab00c4d9a38ff9030b40f0cb27b88b171d58de
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/chat_gradio.py
@@ -0,0 +1,394 @@
+import argparse
+from transformers import AutoTokenizer
+import torch
+import os
+from pointllm.conversation import conv_templates, SeparatorStyle
+from pointllm.utils import disable_torch_init
+from pointllm.model import *
+from pointllm.model.utils import KeywordsStoppingCriteria
+import numpy as np
+
+from pointllm.data import pc_norm, farthest_point_sample
+
+import os
+
+# Additional import for gradio
+import gradio as gr
+import open3d as o3d
+import plotly.graph_objects as go
+import objaverse
+import time
+
+import logging
+
+
+def change_input_method(input_method):
+ if input_method == 'File':
+ result = [gr.update(visible=True),
+ gr.update(visible=False)]
+ elif input_method == 'Object ID':
+ result = [gr.update(visible=False),
+ gr.update(visible=True)]
+ return result
+
+def init_model(args):
+ # Model
+ disable_torch_init()
+ model_name = os.path.expanduser(args.model_name)
+
+ # * print the model_name (get the basename)
+ print(f'[INFO] Model name: {os.path.basename(model_name)}')
+ logging.warning(f'Model name: {os.path.basename(model_name)}')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True).cuda()
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
+
+ model.eval()
+
+ mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
+ # Add special tokens ind to model.point_config
+ point_backbone_config = model.get_model().point_backbone_config
+
+ conv = conv_templates["vicuna_v1_1"].copy()
+
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+
+ return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
+
+def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+ default_point_start_token = point_backbone_config['default_point_start_token']
+ default_point_end_token = point_backbone_config['default_point_end_token']
+
+ # The while loop will keep running until the user decides to quit
+ print("[INFO] Starting conversation...")
+ logging.warning("Starting conversation...")
+ while True:
+ print("-" * 80)
+ logging.warning("-" * 80)
+
+ # Reset the conversation template
+ conv.reset()
+
+ def confirm_point_cloud(input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv):
+ objects = None
+ data = None
+ object_id_input = object_id_input.strip()
+
+ print("%" * 80)
+ logging.warning("%" * 80)
+
+ if input_choice == 'File':
+ file = point_cloud_input.name
+ print(f"Uploading file: {file}.")
+ logging.warning(f"Uploading file: {file}.")
+ elif input_choice == 'Object ID':
+ file = os.path.join(args.data_path, "{}_8192.npy".format(object_id_input))
+ print(f"Object_id: {object_id_input}")
+ logging.warning(f"Object_id: {object_id_input}")
+
+ object_uids = [object_id_input]
+ objects = objaverse.load_objects(uids=object_uids)
+ print("%" * 80)
+ logging.warning("%" * 80)
+
+ manual_no_color = "no_color" in file
+
+ try:
+ if '.ply' in file:
+ pcd = o3d.io.read_point_cloud(file)
+ points = np.asarray(pcd.points) # xyz
+ colors = np.asarray(pcd.colors) # rgb, if available
+ # * if no colors actually, empty array
+ if colors.size == 0:
+ colors = None
+ elif '.npy' in file:
+ data = np.load(file)
+ if data.shape[1] >= 3:
+ points = data[:, :3]
+ else:
+ raise ValueError("Input array has the wrong shape. Expected: [N, 3]. Got: {}.".format(data.shape))
+ colors = None if data.shape[1] < 6 else data[:, 3:6]
+ else:
+ raise ValueError("Not supported data format.")
+ # error
+ except Exception as e:
+ print(f"[ERROR] {e}")
+ logging.warning(f"[ERROR] {e}")
+
+ chatbot_system_message = "Sorry. The Objaverse id is not supported or there is something wrong with the uploaded file!"
+ print(f"[ChatBot System Message]: {chatbot_system_message}")
+ logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
+
+ outputs = f"[System] {chatbot_system_message}" # "You upload a new Points Cloud"
+ chatbot = chatbot + [[None, outputs]]
+
+ return None, None, chatbot, answer_time, None
+
+ if manual_no_color:
+ colors = None
+
+ if colors is not None:
+ # * if colors in range(0-1)
+ if np.max(colors) <= 1:
+ color_data = np.multiply(colors, 255).astype(int) # Convert float values (0-1) to integers (0-255)
+ # * if colors in range(0-255)
+ elif np.max(colors) <= 255:
+ color_data = colors.astype(int)
+ else:
+ color_data = np.zeros_like(points).astype(int) # Default to black color if RGB information is not available
+ colors = color_data.astype(np.float32) / 255 # model input is (0-1)
+
+ # Convert the RGB color data to a list of RGB strings in the format 'rgb(r, g, b)'
+ color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]
+
+ fig = go.Figure(
+ data=[
+ go.Scatter3d(
+ x=points[:, 0], y=points[:, 1], z=points[:, 2],
+ mode='markers',
+ marker=dict(
+ size=1.2,
+ color=color_strings, # Use the list of RGB strings for the marker colors
+ )
+ )
+ ],
+ layout=dict(
+ scene=dict(
+ xaxis=dict(visible=False),
+ yaxis=dict(visible=False),
+ zaxis=dict(visible=False)
+ ),
+ paper_bgcolor='rgb(255,255,255)' # Set the background color to white
+ ),
+ )
+
+ points = np.concatenate((points, colors), axis=1)
+ if 8192 < points.shape[0]:
+ points = farthest_point_sample(points, 8192)
+ point_clouds = pc_norm(points)
+ point_clouds = torch.from_numpy(point_clouds).unsqueeze_(0).to(torch.float32).cuda()
+
+ answer_time = 0
+ conv.reset()
+
+ outputs = "[System] New Point Cloud"
+ chatbot = chatbot + [[None, outputs]]
+
+ return fig, list(objects.values())[0] if objects is not None else None, chatbot, answer_time, point_clouds
+
+ def answer_generate(history, answer_time, point_clouds, conv):
+ if point_clouds is None:
+ outputs = "[System] Please input point cloud! "
+ history[-1][1] = outputs
+ yield history
+ else:
+ print(f"Answer Time: {answer_time}")
+ logging.warning(f"Answer Time: {answer_time}")
+ input_text = history[-1][0]
+ qs = input_text
+
+ if answer_time == 0:
+ if mm_use_point_start_end:
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
+ else:
+ qs = default_point_patch_token * point_token_len + '\n' + qs
+
+ # Append the new message to the conversation history
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ print("#" * 80)
+ print(f'{prompt.replace(default_point_patch_token * point_token_len, f"{default_point_patch_token} * {point_token_len}")}') # for concise printing
+ print("#" * 80)
+
+ logging.warning("#" * 80)
+ logging.warning(f'{prompt.replace(default_point_patch_token * point_token_len, f"{default_point_patch_token} * {point_token_len}")}') # for concise printing
+ logging.warning("#" * 80)
+ inputs = tokenizer([prompt])
+
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
+
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ stop_str = keywords[0]
+
+ try:
+ if input_ids.shape[1] >= 2047:
+ raise ValueError("Current context length exceeds the maximum context length (2048) of the model.")
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ point_clouds=point_clouds,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ max_length=2048,
+ top_p=0.95,
+ stopping_criteria=[stopping_criteria])
+
+ input_token_len = input_ids.shape[1]
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+ if n_diff_input_output > 0:
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+ logging.warning(f'{n_diff_input_output} output_ids are not the same as the input_ids')
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+ outputs = outputs.strip()
+ if outputs.endswith(stop_str):
+ outputs = outputs[:-len(stop_str)]
+ outputs = outputs.strip()
+
+ # Append the model's response to the conversation history
+ conv.pop_last_none_message()
+ conv.append_message(conv.roles[1], outputs)
+ print(f'{conv.roles[1]}: {outputs}\n')
+ logging.warning(f'{conv.roles[1]}: {outputs}\n')
+ answer_time += 1
+ history[-1][1] = ""
+ for character in outputs:
+ history[-1][1] += character
+ yield history
+ # error
+ except Exception as e:
+ print(f"[ERROR] {e}")
+ logging.warning(f"[ERROR] {e}")
+
+ if input_ids.shape[1] >= 2047:
+ chatbot_system_message = "Current context length exceeds the maximum context length (2048) of the model. Please press 'Clear' to restart."
+ else:
+ chatbot_system_message = "Sorry. Something went wrong during generation. Please check your uploaded point cloud or the Objaverse id, and \
+ confirm the point cloud again."
+ print(f"[ChatBot System Message]: {chatbot_system_message}")
+ logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
+
+ outputs = f"[System] {chatbot_system_message}" # "You upload a new Points Cloud"
+ history[-1][1] = outputs
+ yield history
+
+ with gr.Blocks() as demo:
+ answer_time = gr.State(value=0)
+ point_clouds = gr.State(value=None)
+ conv_state = gr.State(value=conv.copy())
+ gr.Markdown(
+ """
+ # PointLLM: Empowering Large Language Models to Understand Point Clouds. 🚀
+ If you think this demo interesting, please consider starring 🌟 our github repo. :)
+ [[Project Page](https://runsenxu.com/projects/PointLLM)] [[Paper](https://arxiv.org/abs/2308.16911)] [[Code](https://github.com/OpenRobotLab/PointLLM)]
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ input_choice = gr.Radio(['File', 'Object ID'], value='Object ID', interactive=True, label='Input Method', info="How do you want to load point clouds?")
+ object_id_input = gr.Textbox(visible = True,lines=1, label='Object ID Input')
+ point_cloud_input = gr.File(visible = False, label="Upload Point Cloud File (PLY, NPY)")
+ output = gr.Plot()
+ btn = gr.Button(value="Confirm Point Cloud")
+ model3D = gr.Model3D()
+ with gr.Column():
+ chatbot = gr.Chatbot([], elem_id="chatbot", height=560) # ,color_map=("green", "pink")
+
+ def user(user_message, history):
+ return "", history + [[user_message, None]]
+
+ def clear_conv(history, conv):
+ conv.reset()
+ return None, 0
+
+ with gr.Row():
+ text_input = gr.Textbox(
+ show_label=False,
+ placeholder="Enter text and press enter",
+ container=False,
+ )
+ run_button = gr.Button("Send")
+
+ clear = gr.Button("Clear")
+ text_input.submit(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1,answer_time, answer_time)
+ clear.click(clear_conv, inputs=[chatbot, conv_state], outputs=[chatbot, answer_time], queue=False)
+
+ btn.click(confirm_point_cloud, inputs=[input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv_state], outputs=[output, model3D, chatbot, answer_time, point_clouds])
+
+ input_choice.change(change_input_method, input_choice, [point_cloud_input, object_id_input])
+ run_button.click(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1, answer_time, answer_time)
+
+ gr.Markdown(
+ """
+ ### Usage:
+ 1. Upload your point cloud file (ply, npy only) or input the supported [Objaverse object id (uid)](https://drive.google.com/file/d/1gLwA7aHfy1KCrGeXlhICG9rT2387tWY8/view?usp=sharing) (currently 660K objects only, you may try the example object ids below).
+ 2. If your point cloud file does not contain colors, include 'no_color' in the file name (e.g., 'xxx_no_color.npy'), and black will be assigned.
+ 3. If uploading your own point cloud file with color in npy format, the first three dimensions should be xyz, and the next three dimensions should be rgb. The rgb values should range from **0 to 1**.
+ 4. Click **Confirm Point Cloud**.
+ 5. As we use FPS sampling to downsample the point cloud to 8192 points, confirming may take a long time if the point cloud has too many points. You may randomly downsample the point cloud before uploading.
+ 6. Once '[System] New Point Cloud' appears in the dialogue box, a new conversation with PointLLM is initialized.
+ 7. The 'Clear' button will clear the conversation history.
+ """)
+ with gr.Accordion("Example Objaverse object ids in the validation set!", open=False):
+ example_object_ids = [ ["b4bbf2116b1a41a5a3b9d3622b07074c", "0b8da82a3d7a436f9b585436c4b72f56", "650c53d68d374c18886aab91bcf8bb54"],
+ ["983fa8b23a084f5dacd157e6c9ceba97", "8fe23dd4bf8542b49c3a574b33e377c3", "83cb2a9e9afb47cd9f45461613796645"],
+ ["3d679a3888c548afb8cf889915af7fd2", "7bcf8626eaca40e592ffd0aed08aa30b", "69865c89fc7344be8ed5c1a54dbddc20"],
+ ["252f3b3f5cd64698826fc1ab42614677", "e85ebb729b02402bbe3b917e1196f8d3", "97367c4740f64935b7a5e34ae1398035"],
+ ["fc8dd5a2fc9f4dd19ad6a64a8a6e89e9", "8257772b0e2f408ba269264855dfea00", "d6a3520486bb474f9b5e72eda8408974"],
+ ["3d10918e6a9a4ad395a7280c022ad2b9", "00002bcb84af4a4781174e62619f14e2", "76ba80230d454de996878c2763fe7e5c"]]
+ gr.DataFrame(
+ type="array",
+ headers=["Example Object IDs"] * 3,
+ row_count=6,
+ col_count=3,
+ value=example_object_ids
+ )
+ gr.Markdown(
+ """
+ #### Terms of use
+ By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+ """
+ )
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=False) # server_port=7832, share=True
+
+if __name__ == "__main__":
+ # ! To release this demo publicly, make sure to launch it from a directory where no important data is stored.
+ # ! Please check 1. the launch dir 2. the tmp dir (GRADIO_TEMP_DIR)
+ # ! refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-name", type=str, \
+ default="RunsenXu/PointLLM_7B_v1.2")
+
+
+ parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
+ parser.add_argument("--pointnum", type=int, default=8192)
+
+ parser.add_argument("--log_file", type=str, default="serving_workdirs/serving_log.txt")
+ parser.add_argument("--tmp_dir", type=str, default="serving_workdirs/tmp")
+
+ # For gradio
+ parser.add_argument("--port", type=int, default=7810)
+
+ args = parser.parse_args()
+
+ # * make serving dirs
+ os.makedirs(os.path.dirname(args.log_file), exist_ok=True)
+ os.makedirs(args.tmp_dir, exist_ok=True)
+
+ # * add the current time for log name
+ args.log_file = args.log_file.replace(".txt", f"_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.txt")
+
+ logging.basicConfig(
+ filename=args.log_file,
+ level=logging.WARNING, # * default gradio is info, so use warning
+ format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ logging.warning("-----New Run-----")
+ logging.warning(f"args: {args}")
+
+ print("-----New Run-----")
+ print(f"[INFO] Args: {args}")
+
+ # * set env variable GRADIO_TEMP_DIR to args.tmp_dir
+ os.environ["GRADIO_TEMP_DIR"] = args.tmp_dir
+
+ model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
+ start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
diff --git a/ThirdParty/PointLLM/pointllm/eval/eval_modelnet_cls.py b/ThirdParty/PointLLM/pointllm/eval/eval_modelnet_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..65752c03a8be53a2269dbc948a96229e0176e6f0
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/eval_modelnet_cls.py
@@ -0,0 +1,195 @@
+import argparse
+import torch
+from torch.utils.data import DataLoader
+import os
+from pointllm.conversation import conv_templates, SeparatorStyle
+from pointllm.utils import disable_torch_init
+from pointllm.model.utils import KeywordsStoppingCriteria
+from pointllm.model import PointLLMLlamaForCausalLM
+from pointllm.data import ModelNet
+from tqdm import tqdm
+from pointllm.eval.evaluator import start_evaluation
+from transformers import AutoTokenizer
+
+import os
+import json
+
+PROMPT_LISTS = [
+ "What is this?",
+ "This is an object of "
+]
+
+def init_model(args):
+ # Model
+ disable_torch_init()
+ model_name = os.path.expanduser(args.model_name)
+
+ # * print the model_name (get the basename)
+ print(f'[INFO] Model name: {os.path.basename(model_name)}')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True, torch_dtype=torch.bfloat16).cuda()
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
+
+ conv_mode = "vicuna_v1_1"
+
+ conv = conv_templates[conv_mode].copy()
+
+ return model, tokenizer, conv
+
+def load_dataset(config_path, split, subset_nums, use_color):
+ print(f"Loading {split} split of ModelNet datasets.")
+ dataset = ModelNet(config_path=config_path, split=split, subset_nums=subset_nums, use_color=use_color)
+ print("Done!")
+ return dataset
+
+def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
+ assert shuffle is False, "Since we use the index of ModelNet as the Object ID during evaluation, \
+ shuffle should be False and the random seed should always be fixed."
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+ return dataloader
+
+def generate_outputs(model, tokenizer, input_ids, point_clouds, stopping_criteria, do_sample=True, temperature=1.0, top_k=50, max_length=2048, top_p=0.95):
+ model.eval()
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ point_clouds=point_clouds,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ max_length=max_length,
+ top_p=top_p,
+ stopping_criteria=[stopping_criteria]) # * B, L'
+
+ input_token_len = input_ids.shape[1]
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+ if n_diff_input_output > 0:
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
+ outputs = [output.strip() for output in outputs]
+
+ return outputs
+
+def start_generation(model, tokenizer, conv, dataloader, prompt_index, output_dir, output_file):
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ qs = PROMPT_LISTS[prompt_index]
+
+ results = {"prompt": qs}
+
+ point_backbone_config = model.get_model().point_backbone_config
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+ default_point_start_token = point_backbone_config['default_point_start_token']
+ default_point_end_token = point_backbone_config['default_point_end_token']
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
+
+ if mm_use_point_start_end:
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
+ else:
+ qs = default_point_patch_token * point_token_len + '\n' + qs
+
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+
+ prompt = conv.get_prompt()
+ inputs = tokenizer([prompt])
+
+ input_ids_ = torch.as_tensor(inputs.input_ids).cuda() # * tensor of 1, L
+
+ stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids_)
+
+ responses = []
+
+ for batch in tqdm(dataloader):
+ point_clouds = batch["point_clouds"].cuda().to(model.dtype) # * tensor of B, N, C(3)
+ labels = batch["labels"]
+ label_names = batch["label_names"]
+ indice = batch["indice"]
+
+ batchsize = point_clouds.shape[0]
+
+ input_ids = input_ids_.repeat(batchsize, 1) # * tensor of B, L
+
+ outputs = generate_outputs(model, tokenizer, input_ids, point_clouds, stopping_criteria) # List of str, length is B
+
+ # saving results
+ for index, output, label, label_name in zip(indice, outputs, labels, label_names):
+ responses.append({
+ "object_id": index.item(),
+ "ground_truth": label.item(),
+ "model_output": output,
+ "label_name": label_name
+ })
+
+ results["results"] = responses
+
+ os.makedirs(output_dir, exist_ok=True)
+ # save the results to a JSON file
+ with open(os.path.join(output_dir, output_file), 'w') as fp:
+ json.dump(results, fp, indent=2)
+
+ # * print info
+ print(f"Saved results to {os.path.join(output_dir, output_file)}")
+
+ return results
+
+def main(args):
+ # * output
+ args.output_dir = os.path.join(args.model_name, "evaluation")
+
+ # * output file
+ args.output_file = f"ModelNet_classification_prompt{args.prompt_index}.json"
+ args.output_file_path = os.path.join(args.output_dir, args.output_file)
+
+ # * First inferencing, then evaluate
+ if not os.path.exists(args.output_file_path):
+ # * need to generate results first
+ dataset = load_dataset(config_path=None, split=args.split, subset_nums=args.subset_nums, use_color=args.use_color) # * default config
+ dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)
+
+ model, tokenizer, conv = init_model(args)
+
+ # * output
+ print(f'[INFO] Start generating results for {args.output_file}.')
+ results = start_generation(model, tokenizer, conv, dataloader, args.prompt_index, args.output_dir, args.output_file)
+
+ # * release model and tokenizer, and release cuda memory
+ del model
+ del tokenizer
+ torch.cuda.empty_cache()
+ else:
+ # * directly load the results
+ print(f'[INFO] {args.output_file_path} already exists, directly loading...')
+ with open(args.output_file_path, 'r') as fp:
+ results = json.load(fp)
+
+ # * evaluation file
+ evaluated_output_file = args.output_file.replace(".json", f"_evaluated_{args.gpt_type}.json")
+ # * start evaluation
+ if args.start_eval:
+ start_evaluation(results, output_dir=args.output_dir, output_file=evaluated_output_file, eval_type="modelnet-close-set-classification", model_type=args.gpt_type, parallel=True, num_workers=20)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, \
+ default="RunsenXu/PointLLM_7B_v1.2")
+
+ # * dataset type
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
+ parser.add_argument("--use_color", action="store_true", default=True)
+
+ # * data loader, batch_size, shuffle, num_workers
+ parser.add_argument("--batch_size", type=int, default=30)
+ parser.add_argument("--shuffle", type=bool, default=False)
+ parser.add_argument("--num_workers", type=int, default=20)
+ parser.add_argument("--subset_nums", type=int, default=-1) # * only use "subset_nums" of samples, mainly for debug
+
+ # * evaluation setting
+ parser.add_argument("--prompt_index", type=int, default=0)
+ parser.add_argument("--start_eval", action="store_true", default=False)
+ parser.add_argument("--gpt_type", type=str, default="gpt-3.5-turbo-0613", choices=["gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-4-0613", "gpt-4-1106-preview"], help="Type of the model used to evaluate.")
+
+ args = parser.parse_args()
+
+ main(args)
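+
+# Illustrative invocation (paths and flags below are only an example; adjust to your setup):
+#   python pointllm/eval/eval_modelnet_cls.py --model_name RunsenXu/PointLLM_7B_v1.2 \
+#       --prompt_index 0 --start_eval --gpt_type gpt-3.5-turbo-0613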
diff --git a/ThirdParty/PointLLM/pointllm/eval/eval_objaverse.py b/ThirdParty/PointLLM/pointllm/eval/eval_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..c92b2656220986bbc97181726d04735698636eac
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/eval_objaverse.py
@@ -0,0 +1,220 @@
+import argparse
+import torch
+from torch.utils.data import DataLoader
+import os
+from pointllm.conversation import conv_templates, SeparatorStyle
+from pointllm.utils import disable_torch_init
+from pointllm.model import *
+from pointllm.model.utils import KeywordsStoppingCriteria
+from pointllm.data import ObjectPointCloudDataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+from pointllm.eval.evaluator import start_evaluation
+
+import os
+import json
+
+PROMPT_LISTS = [
+ "What is this?",
+ "This is an object of ",
+ "Caption this 3D model in detail."
+]
+
+def init_model(args):
+ # Model
+ disable_torch_init()
+ model_name = os.path.expanduser(args.model_name)
+
+ # * print the model_name (get the basename)
+ print(f'[INFO] Model name: {os.path.basename(model_name)}')
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True, torch_dtype=torch.bfloat16).cuda()
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
+
+ conv_mode = "vicuna_v1_1"
+
+ conv = conv_templates[conv_mode].copy()
+
+ return model, tokenizer, conv
+
+def load_dataset(data_path, anno_path, pointnum, conversation_types, use_color):
+ print("Loading validation datasets.")
+ dataset = ObjectPointCloudDataset(
+ data_path=data_path,
+ anno_path=anno_path,
+ pointnum=pointnum,
+ conversation_types=conversation_types,
+ use_color=use_color,
+ tokenizer=None # * load point cloud only
+ )
+ print("Done!")
+ return dataset
+
+def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+ return dataloader
+
+def generate_outputs(model, tokenizer, input_ids, point_clouds, stopping_criteria, do_sample=True, temperature=1.0, top_k=50, max_length=2048, top_p=0.95):
+ model.eval()
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ point_clouds=point_clouds,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ max_length=max_length,
+ top_p=top_p,
+ stopping_criteria=[stopping_criteria]) # * B, L'
+
+ input_token_len = input_ids.shape[1]
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+ if n_diff_input_output > 0:
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
+ outputs = [output.strip() for output in outputs]
+
+ return outputs
+
+def start_generation(model, tokenizer, conv, dataloader, annos, prompt_index, output_dir, output_file):
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ qs = PROMPT_LISTS[prompt_index]
+
+ results = {"prompt": qs}
+
+ point_backbone_config = model.get_model().point_backbone_config
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+ default_point_start_token = point_backbone_config['default_point_start_token']
+ default_point_end_token = point_backbone_config['default_point_end_token']
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
+
+ if mm_use_point_start_end:
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
+ else:
+ qs = default_point_patch_token * point_token_len + '\n' + qs
+
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+
+ prompt = conv.get_prompt()
+ inputs = tokenizer([prompt])
+
+ input_ids_ = torch.as_tensor(inputs.input_ids).cuda() # * tensor of 1, L
+
+ stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids_)
+
+ responses = []
+
+ for batch in tqdm(dataloader):
+ point_clouds = batch["point_clouds"].cuda().to(model.dtype) # * tensor of B, N, C(3)
+ object_ids = batch["object_ids"] # * list of string
+
+ batchsize = len(object_ids)
+
+ input_ids = input_ids_.repeat(batchsize, 1) # * tensor of B, L
+
+ outputs = generate_outputs(model, tokenizer, input_ids, point_clouds, stopping_criteria) # List of str, length is B
+
+ # saving results
+ for obj_id, output in zip(object_ids, outputs):
+ responses.append({
+ "object_id": obj_id,
+ "ground_truth": annos[obj_id],
+ "model_output": output
+ })
+
+ results["results"] = responses
+
+ os.makedirs(output_dir, exist_ok=True)
+ # save the results to a JSON file
+ with open(os.path.join(output_dir, output_file), 'w') as fp:
+ json.dump(results, fp, indent=2)
+
+ # * print info
+ print(f"Saved results to {os.path.join(output_dir, output_file)}")
+
+ return results
+
+def main(args):
+ # * output
+ args.output_dir = os.path.join(args.model_name, "evaluation")
+
+ # * output file
+ anno_file = os.path.splitext(os.path.basename(args.anno_path))[0]
+ args.output_file = f"{anno_file}_Objaverse_{args.task_type}_prompt{args.prompt_index}.json"
+ args.output_file_path = os.path.join(args.output_dir, args.output_file)
+
+ # * First inferencing, then evaluate
+ if not os.path.exists(args.output_file_path):
+ # * need inferencing
+ # * load annotation files
+ with open(args.anno_path, 'r') as fp:
+ annos = json.load(fp)
+
+ dataset = load_dataset(args.data_path, args.anno_path, args.pointnum, ("simple_description",), args.use_color)
+ dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)
+
+ model, tokenizer, conv = init_model(args)
+
+ # * convert annos file from [{"object_id": }] to {"object_id": }
+ annos = {anno["object_id"]: anno["conversations"][1]['value'] for anno in annos}
+
+ print(f'[INFO] Start generating results for {args.output_file}.')
+ results = start_generation(model, tokenizer, conv, dataloader, annos, args.prompt_index, args.output_dir, args.output_file)
+
+ # * release model and tokenizer, and release cuda memory
+ del model
+ del tokenizer
+ torch.cuda.empty_cache()
+ else:
+ # * directly load the results
+ print(f'[INFO] {args.output_file_path} already exists, directly loading...')
+ with open(args.output_file_path, 'r') as fp:
+ results = json.load(fp)
+
+ if args.start_eval:
+ evaluated_output_file = args.output_file.replace(".json", f"_evaluated_{args.gpt_type}.json")
+ eval_type_mapping = {
+ "captioning": "object-captioning",
+ "classification": "open-free-form-classification"
+ }
+ start_evaluation(results, output_dir=args.output_dir, output_file=evaluated_output_file, eval_type=eval_type_mapping[args.task_type], model_type=args.gpt_type, parallel=True, num_workers=20)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, \
+ default="RunsenXu/PointLLM_7B_v1.2")
+
+ # * dataset type
+ parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
+ parser.add_argument("--anno_path", type=str, default="data/anno_data/PointLLM_brief_description_val_200_GT.json", required=False)
+ parser.add_argument("--pointnum", type=int, default=8192)
+ parser.add_argument("--use_color", action="store_true", default=True)
+
+ # * data loader, batch_size, shuffle, num_workers
+ parser.add_argument("--batch_size", type=int, default=6)
+ parser.add_argument("--shuffle", type=bool, default=False)
+ parser.add_argument("--num_workers", type=int, default=10)
+
+ # * evaluation setting
+ parser.add_argument("--prompt_index", type=int, default=0)
+ parser.add_argument("--start_eval", action="store_true", default=False)
+ parser.add_argument("--gpt_type", type=str, default="gpt-4-0613", choices=["gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-4-0613", "gpt-4-1106-preview"], help="Type of the model used to evaluate.")
+ parser.add_argument("--task_type", type=str, default="captioning", choices=["captioning", "classification"], help="Type of the task to evaluate.")
+
+ args = parser.parse_args()
+
+ # * check prompt index
+ # * * classification: 0, 1 and captioning: 2. Raise Warning otherwise.
+ if args.task_type == "classification":
+ if args.prompt_index != 0 and args.prompt_index != 1:
+ print("[Warning] For classification task, prompt_index should be 0 or 1.")
+ elif args.task_type == "captioning":
+ if args.prompt_index != 2:
+ print("[Warning] For captioning task, prompt_index should be 2.")
+ else:
+ raise NotImplementedError
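+
+ # Illustrative invocation for captioning (flags shown are only an example; adjust to your setup):
+ #   python pointllm/eval/eval_objaverse.py --task_type captioning --prompt_index 2 \
+ #       --start_eval --gpt_type gpt-4-0613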
+
+ main(args)
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/eval/evaluator.py b/ThirdParty/PointLLM/pointllm/eval/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a723589ba605e0eb823e2fe77b0777bfb17f5e8
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/evaluator.py
@@ -0,0 +1,843 @@
+import argparse
+import json
+import os
+from utils import OpenAIGPT
+from tqdm import tqdm
+from multiprocessing import Pool
+import random
+random.seed(0)
+import re
+
+gpt4_open_free_from_cls_prompt = """Analyze two sentences and determine if they're referring to the same general object or concept, focusing on the type of object, not attributes such as color, size, or shape. Respond with 'T' if they refer to the same thing and 'F' if not. Also, provide a brief rationale (no more than 20 words) for your judgment.
+Example:
+Input: 1. Spiral staircase that goes from a ground floor. 2. This is a 3D model of wooden stairs in light brown
+Output: T#Both refer to a staircase.
+
+Now, analyze the following:
+Input: 1. {ground_truth} 2. {model_output}
+Output: """ # * about 230 input tokens
+
+chatgpt_close_set_cls_prompt = """Given the following free-form description of a 3D object, please determine the most probable class index from the following 40 available categories, even if the description doesn't clearly refer to any one of them. Make your best-educated guess based on the information provided. If the description already contains a valid index, then the index should be selected. If it contains more than one valid index, then randomly select one index (specify your reason). If there is no valid index and it cannot be inferred from the information, return '-1#NA#Cannot infer'.
+Categories:
+{candidate_lists}
+Reply with the format of 'index#class#short reason (no more than 10 words)'.
+
+Examples:
+Input: This is a 3D object model of a cartoon white truck.
+Output: 7#car#Closest match to 'car' in categories.
+
+Input: A green leaf in a flower pot.
+Output: 26#plant#The primary subject 'leaf' directly indicates a plant.
+
+Input: It's difficult to determine the exact type of this object due to insufficient details. But it seems to be like a piece of furniture.
+Output: 33#table#Randomly select one kind of furniture from the list.
+
+Input: I cannot determine the specific type of the object without additional information or context.
+Output: -1#NA#Cannot infer.
+
+Now analyze the following:
+Input: """
+
+gpt4_object_captioning_prompt = """Evaluate a model-generated caption against a human-generated caption (ground truth) for a 3D model. Identify the aspects mentioned in the human caption and calculate the percentage of these aspects correctly mentioned or partially matched in the model caption. Score from 0 to 100, where each aspect contributes equally to the score. Consider similar concepts for partial score.
+
+Provide your score (0-100) and a short justification (less than 15 words) in the format of 'score#reason'
+
+Example:
+Human: A white brown skeleton
+Model: This is a 3D model of a small, cartoon-like robot. It has a spherical body and is covered in a layer of white dust.
+Output: 50#mention white; skeleton and robot have similar appearance.
+
+Now score the following:
+Human: {ground_truth}
+Model: {model_output}
+Output: """
+
+chatgpt_object_captioning_prompt = gpt4_object_captioning_prompt
+chatgpt_open_free_from_cls_prompt = gpt4_open_free_from_cls_prompt
+gpt4_close_set_cls_prompt = chatgpt_close_set_cls_prompt
+
+GPT_PRICES = {
+ # * check https://openai.com/pricing for updated price
+ "gpt-3.5-turbo-0613": {
+ "price_1k_prompt_tokens": 0.0015,
+ "price_1k_completion_tokens": 0.002
+ },
+ "gpt-3.5-turbo-1106": {
+ "price_1k_prompt_tokens": 0.0010,
+ "price_1k_completion_tokens": 0.002
+ },
+ "gpt-4-0613":{
+ "price_1k_prompt_tokens": 0.03,
+ "price_1k_completion_tokens": 0.06
+ },
+ "gpt-4-1106-preview":{
+ "price_1k_prompt_tokens": 0.01,
+ "price_1k_completion_tokens": 0.03
+ }
+}
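+
+# Rough per-call cost, using the prices above (prices may have changed since):
+#   cost_usd = prompt_tokens / 1000 * price_1k_prompt_tokens + completion_tokens / 1000 * price_1k_completion_tokens
+# e.g. one gpt-4-0613 call with ~230 prompt tokens and ~20 completion tokens costs about
+# 230/1000 * 0.03 + 20/1000 * 0.06 ≈ 0.008 USD.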
+
+class OpenAIOpenFreeFormClsEvaluator():
+ def __init__(self, inputs, output_dir, output_file, model_type="gpt-4-0613"):
+ """
+ Args:
+ inputs: A dictionary containing the results of the evaluation. It contains two keys: "results" and "prompt".
+ "prompt": str
+ "results": [
+ {
+ "object_id": str,
+ "model_output": str,
+ "ground_truth": str
+ }
+ ]
+ """
+ print("-" * 80)
+ print("Initializing OpenAIEvaluator...")
+ self.results = inputs['results']  # * contains two keys: "results" and "prompt"
+ self.inference_prompt = inputs['prompt'] # * used to prompt PointLLM
+ self.correct_predictions = 0
+ self.total_predictions = 0
+ self.invalid_responses = 0
+ self.response_data = [] # to save all the response data by openaigpt
+ self.model_type = model_type
+ self.check_model_type()
+
+ self.prompt_tokens = 0
+ self.completion_tokens = 0
+
+ self.default_chat_parameters = {
+ "model": model_type,
+ "temperature": 1,
+ "top_p": 1,
+ "max_tokens": 2048
+ }
+
+ # * price
+ self.price_1k_prompt_tokens = GPT_PRICES[model_type]["price_1k_prompt_tokens"]
+ self.price_1k_completion_tokens = GPT_PRICES[model_type]["price_1k_completion_tokens"]
+
+ print(f"OpenAIGPT config: ")
+ print(self.default_chat_parameters)
+
+ self.openaigpt = OpenAIGPT(**self.default_chat_parameters)
+ self.gpt_prompt = chatgpt_open_free_from_cls_prompt if "gpt-3.5" in model_type else gpt4_open_free_from_cls_prompt
+ self.output_dir = output_dir
+ self.output_file = output_file
+ self.temp_output_file = self.output_file.replace(".json", "_processed_temp.json")
+
+ def check_model_type(self):
+ # * warning if not using gpt-4, recommend using gpt-4 for this task
+ if "gpt-4" not in self.model_type:
+ print(f"[WARNING] You are using {self.model_type} for evaluation. We recommend using gpt-4 for this task.")
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.correct_predictions = saved_results["correct_predictions"]
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def remove_temp_file(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ os.remove(processed_results_path)
+ print("-" * 80)
+ print(f"Removed Temporary file {processed_results_path}")
+
+ def parse_gpt_response_evaluate(self, gpt_response):
+ gpt_response = gpt_response.strip()
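+ # Expected format (see gpt4_open_free_from_cls_prompt above), e.g. "T#Both refer to a staircase."
+ # parses to accuracy=1, cls_result='T', reason='Both refer to a staircase.'.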
+
+ cls_result = gpt_response[0].upper()
+ reason = gpt_response[2:] if len(gpt_response) > 2 else ""
+
+ if cls_result not in ['T', 'F']:
+ self.invalid_responses += 1
+ return 0, "INVALID", gpt_response
+
+ accuracy = 1 if cls_result == 'T' else 0
+
+ return accuracy, cls_result, reason
+
+ def evaluate_result(self, result):
+ object_id = result['object_id']
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+ messages = [{"role": "user", "content": self.gpt_prompt.format(ground_truth=ground_truth, model_output=model_output)}]
+
+ gpt_response = self.openaigpt.safe_chat_complete(messages, content_only=False)
+
+ prompt_tokens = gpt_response['usage']['prompt_tokens']
+ completion_tokens = gpt_response['usage']['completion_tokens']
+
+ gpt_response = gpt_response['choices'][0]["message"]['content']
+
+
+ accuracy, cls_result, reason = self.parse_gpt_response_evaluate(gpt_response) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+ object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens = self.evaluate_result(result)
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'gpt_cls_result': cls_result,
+ 'gpt_reason': reason
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ with Pool(num_workers) as pool:
+ with tqdm(total=len(results)) as pbar: # create a progress bar
+ for object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens in pool.imap_unordered(self.evaluate_result, results):
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if cls_result == 'INVALID':
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'gpt_cls_result': cls_result,
+ 'gpt_reason': reason
+ })
+
+ pbar.update() # update the progress bar
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ else:
+ accuracy = self.correct_predictions / (self.total_predictions - self.invalid_responses) * 100
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'prompt': self.gpt_prompt,
+ 'accuracy': f"{accuracy:.2f}%",
+ 'total_predictions': self.total_predictions,
+ 'correct_predictions': self.correct_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ else:
+ accuracy = self.correct_predictions / (self.total_predictions - self.invalid_responses) * 100
+ print("Results:")
+ print(f"Accuracy: {accuracy:.2f}%")
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Correct Predictions: {self.correct_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ self.print_costs()
+
+ def print_costs(self):
+ print(f"Prompt Tokens Price: {self.prompt_tokens * self.price_1k_prompt_tokens / 1000:.2f} USD")
+ print(f"Completion Tokens Price: {self.completion_tokens * self.price_1k_completion_tokens / 1000:.2f} USD")
+
+ def get_costs(self):
+ return self.prompt_tokens * self.price_1k_prompt_tokens / 1000 + self.completion_tokens * self.price_1k_completion_tokens / 1000
+
+
+class OpenAICloseSetClsEvaluator(OpenAIOpenFreeFormClsEvaluator):
+ def __init__(self, inputs, output_dir, output_file, model_type="gpt-3.5-turbo-0613"):
+ super().__init__(inputs, output_dir, output_file, model_type)
+ self.gpt_prompt = chatgpt_close_set_cls_prompt if "gpt-3.5" in model_type else gpt4_close_set_cls_prompt
+
+ self.invalid_correct_predictions = 0 # * random choice and correct coincidently
+
+ # * import category names
+ try:
+ # * load a txt files of category names
+ catfile = os.path.join(os.path.dirname(__file__), '../data/modelnet_config/modelnet40_shape_names_modified.txt') # * i.e. pointllm/data/modelnet_config/modelnet40_shape_names_modified.txt
+ self.candidate_lists_names = [line.strip() for line in open(catfile)] # * list of category names
+ except Exception:
+ print(f"Failed to load the categories file {catfile}. Please make sure the ModelNet40 category file exists at this path.")
+
+ # * make the prompt
+ candidate_lists = [f'{i}: {cat}' for i, cat in enumerate(self.candidate_lists_names)]
+ self.num_categories = len(candidate_lists)
+ self.candidate_lists = '\n'.join(candidate_lists)
+ self.gpt_prompt = self.gpt_prompt.format(num_categories=self.num_categories, candidate_lists=self.candidate_lists) + "{model_output}\nOutput: "
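+ # After formatting, self.gpt_prompt embeds the enumerated ModelNet40 category list
+ # (one "index: name" line per category, taken from the shape-names file above),
+ # followed by "{model_output}\nOutput: ", which is filled in per evaluated sample.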
+
+ def check_model_type(self):
+ # * no need to check for this task
+ return
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.correct_predictions = saved_results["correct_predictions"]
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.invalid_correct_predictions = saved_results["invalid_correct_predictions"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def parse_gpt_response_evaluate(self, gpt_response, ground_truth):
+ """
+ Argument:
+ gpt_response: str, index#label#short_reason
+ ground_truth: int
+ """
+
+ # * use regular expression to extract
+ pattern = r'(\d+#[^#]*#.*$)'
+ match = re.search(pattern, gpt_response)
+
+ gpt_response = match.group(1) if match else gpt_response
+
+ gpt_response = gpt_response.strip()
+ gpt_response_list = gpt_response.split('#')
+
+ cls_result = gpt_response_list[0]
+ cls_label = gpt_response_list[1] if len(gpt_response_list) > 1 else ""
+ reason = gpt_response_list[2] if len(gpt_response_list) > 2 else ""
+
+ try:
+ # * convert to int
+ cls_result = int(cls_result)
+ if cls_result not in range(self.num_categories) or cls_label == "NA":
+ # * not valid range
+ cls_result = -1
+ except ValueError:
+ print(f"Error: unale to parse {gpt_response}.")
+ cls_result = -1
+
+ if cls_result == -1:
+ # * random choose one index from 0 to self.num_categories
+ cls_result = random.choice(range(self.num_categories))
+ cls_label = "INVALID"
+ reason = gpt_response
+
+ self.invalid_responses += 1
+
+ accuracy = 1 if cls_result == ground_truth else 0
+
+ return accuracy, cls_result, cls_label, reason
+
+ def evaluate_result(self, result):
+ object_id = result.get('object_id', -1)
+ ground_truth = result['ground_truth']
+ ground_truth_label = result['label_name']
+ model_output = result['model_output']
+
+ messages = [{"role": "user", "content": self.gpt_prompt.format(model_output=model_output)}]
+
+ gpt_response = self.openaigpt.safe_chat_complete(messages, content_only=False)
+
+ prompt_tokens = gpt_response['usage']['prompt_tokens']
+ completion_tokens = gpt_response['usage']['completion_tokens']
+
+ gpt_response = gpt_response['choices'][0]["message"]['content']
+
+ accuracy, cls_result, cls_label, reason = self.parse_gpt_response_evaluate(gpt_response, ground_truth) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+ object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens = self.evaluate_result(result)
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+
+ if cls_label == "INVALID":
+ self.invalid_correct_predictions += accuracy
+ self.invalid_responses += 1
+
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'gpt_cls_result': cls_result,
+ 'ground_truth_label': ground_truth_label,
+ 'gpt_cls_label': cls_label,
+ 'model_output': model_output,
+ 'gpt_reason': reason,
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ print(f"Current sample is {result}.")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ with Pool(num_workers) as pool:
+ with tqdm(total=len(results)) as pbar: # create a progress bar
+ for object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens in pool.imap_unordered(self.evaluate_result, results):
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if cls_label == "INVALID":
+ self.invalid_correct_predictions += accuracy
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'gpt_cls_result': cls_result,
+ 'ground_truth_label': ground_truth_label,
+ 'gpt_cls_label': cls_label,
+ 'model_output': model_output,
+ 'gpt_reason': reason,
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens
+ })
+
+ pbar.update() # update the progress bar
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ clean_accuracy = 0
+ else:
+ accuracy = self.correct_predictions / self.total_predictions * 100
+ clean_accuracy = (self.correct_predictions - self.invalid_correct_predictions) / (self.total_predictions - self.invalid_responses) * 100
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'prompt': self.gpt_prompt,
+ 'accuracy': f"{accuracy:.2f}%",
+ 'clean_accuracy': f"{clean_accuracy:.2f}%",
+ 'total_predictions': self.total_predictions,
+ 'correct_predictions': self.correct_predictions,
+ 'invalid_correct_predictions': self.invalid_correct_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ clean_accuracy = 0
+ else:
+ accuracy = self.correct_predictions / self.total_predictions * 100
+ clean_accuracy = (self.correct_predictions - self.invalid_correct_predictions) / (self.total_predictions - self.invalid_responses) * 100
+ print("Results:")
+ print(f"Accuracy: {accuracy:.2f}%")
+ print(f"Clean Accuracy: {clean_accuracy:.2f}%",)
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Correct Predictions: {self.correct_predictions}")
+ print(f"Invalid Correct Predictions: {self.invalid_correct_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ print(f"Prompt Tokens: {self.prompt_tokens}")
+ print(f"Completion Tokens: {self.completion_tokens}")
+
+ self.print_costs()
+
+class OpenAIObjectCaptioningEvaluator(OpenAIOpenFreeFormClsEvaluator):
+ def __init__(self, inputs, output_dir, output_file, model_type="gpt-4-0613"):
+ super().__init__(inputs, output_dir, output_file, model_type)
+ self.gpt_prompt = chatgpt_object_captioning_prompt if "gpt-3.5" in model_type else gpt4_object_captioning_prompt
+
+ self.total_scores = 0
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.total_scores = float(saved_results["total_score"])
+
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def parse_gpt_response_evaluate(self, gpt_response, ground_truth):
+ """
+ Argument:
+ gpt_response: str, score#short_reason
+ ground_truth: str (the human-written caption)
+ """
+
+ # * use regular expression to extract
+ pattern = r'(\d*#.*)'
+ match = re.search(pattern, gpt_response)
+
+ gpt_response = match.group(1) if match else gpt_response
+
+ gpt_response = gpt_response.strip()
+ gpt_response_list = gpt_response.split('#')
+
+ gpt_score = gpt_response_list[0]
+ reason = gpt_response_list[1] if len(gpt_response_list) > 1 else ""
+
+ try:
+ # * convert to int
+ gpt_score = int(gpt_score)
+ if gpt_score not in range(101): # * in 0-100
+ # * not valid range
+ gpt_score = -1
+ except ValueError:
+ print(f"Error: unale to parse {gpt_response}.")
+ gpt_score = -1
+
+ if gpt_score == -1:
+ reason = gpt_response
+
+ return gpt_score, reason
+
+ def evaluate_result(self, result):
+ object_id = result.get('object_id', -1)
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+
+ messages = [{"role": "user", "content": self.gpt_prompt.format(ground_truth=ground_truth, model_output=model_output)}]
+
+ gpt_response = self.openaigpt.safe_chat_complete(messages, content_only=False)
+
+ prompt_tokens = gpt_response['usage']['prompt_tokens']
+ completion_tokens = gpt_response['usage']['completion_tokens']
+
+ gpt_response = gpt_response['choices'][0]["message"]['content']
+
+ gpt_score, reason = self.parse_gpt_response_evaluate(gpt_response, ground_truth) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+ object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens = self.evaluate_result(result)
+
+ self.total_scores += gpt_score if gpt_score != -1 else 0
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if gpt_score == -1:
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ "gpt_score": gpt_score,
+ 'gpt_reason': reason
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ with Pool(num_workers) as pool:
+ with tqdm(total=len(results)) as pbar: # create a progress bar
+ for object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens in pool.imap_unordered(self.evaluate_result, results):
+ self.total_scores += gpt_score if gpt_score != -1 else 0
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if gpt_score == -1:
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ "gpt_score": gpt_score,
+ 'gpt_reason': reason
+ })
+
+ pbar.update() # update the progress bar
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ average_score = 0 # * no results and get error
+ else:
+ average_score = self.total_scores / (self.total_predictions - self.invalid_responses)
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'gpt_prompt': self.gpt_prompt,
+ 'average_score': f"{average_score:.2f}",
+ 'total_score': f"{self.total_scores:.2f}",
+ 'total_predictions': self.total_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ average_score = 0 # * no results and get error
+ else:
+ average_score = self.total_scores / (self.total_predictions - self.invalid_responses)
+ print("Results:")
+ print(f"Average Score: {average_score:.2f}")
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ print(f"Prompt Tokens: {self.prompt_tokens}")
+ print(f"Completion Tokens: {self.completion_tokens}")
+
+ self.print_costs()
+
+
+def start_evaluation(results, output_dir, output_file, eval_type="open-free-form-classification", model_type="gpt-3.5-turbo-0613",
+ parallel=True, num_workers=20):
+ """
+ Args:
+ results: dict or file path to the json file containing the dict
+ output_file: the path the final evaluation results to be saved.
+ """
+ if isinstance(results, str):
+ with open(results, 'r') as fp:
+ results = json.load(fp)
+
+ if eval_type == "open-free-form-classification":
+ evaluator = OpenAIOpenFreeFormClsEvaluator(results, output_dir, output_file, model_type=model_type)
+ elif eval_type == "modelnet-close-set-classification":
+ evaluator = OpenAICloseSetClsEvaluator(results, output_dir, output_file, model_type=model_type)
+ elif eval_type == "object-captioning":
+ evaluator = OpenAIObjectCaptioningEvaluator(results, output_dir, output_file, model_type=model_type)
+ else:
+ raise NotImplementedError(f"eval_type {eval_type} not supported.")
+
+ if parallel:
+ evaluator.parallel_evaluate(num_workers=num_workers)
+ else:
+ evaluator.evaluate()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--results_path", type=str, \
+ default="", help="Path to the results file.")
+ parser.add_argument("--output_dir", type=str, default=None, help="Path to the output directory.")
+ parser.add_argument("--model_type", type=str, default="gpt-4-0613", choices=["gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-4-0613", "gpt-4-1106-preview"], help="Type of the model used to evaluate.")
+ parser.add_argument("--parallel", default=True, action="store_true", help="Whether to use parallel evaluation.")
+ parser.add_argument("--num_workers", type=int, default=15, help="Number of workers to use for parallel evaluation.")
+ parser.add_argument("--eval_type", type=str, choices=["modelnet-close-set-classification", "open-free-form-classification", "object-captioning"], default="object-captioning")
+
+ args = parser.parse_args()
+
+ if args.output_dir is None:
+ args.output_dir = os.path.dirname(args.results_path)
+
+ output_file = os.path.basename(args.results_path).replace(".json", f"_evaluated_{args.model_type}.json")
+
+ # if exists, then exit
+ if os.path.exists(os.path.join(args.output_dir, output_file)):
+ print(f"[INFO] Evaulated results already exists in {os.path.join(args.output_dir, output_file)}.")
+ exit()
+
+ start_evaluation(results=args.results_path, output_dir=args.output_dir, output_file=output_file, eval_type=args.eval_type, model_type=args.model_type,
+ parallel=args.parallel, num_workers=args.num_workers)
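+
+ # Illustrative invocation (the results path is a placeholder):
+ #   python pointllm/eval/evaluator.py --results_path <model_dir>/evaluation/results.json \
+ #       --eval_type object-captioning --model_type gpt-4-0613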
+
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/eval/traditional_evaluator.py b/ThirdParty/PointLLM/pointllm/eval/traditional_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..29a4c337b251fac5016bb49f9593b12ec2c7ff95
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/traditional_evaluator.py
@@ -0,0 +1,179 @@
+import argparse
+import json
+import os
+import random
+random.seed(0)
+
+import nltk
+nltk.download('wordnet')
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from nltk.translate.meteor_score import meteor_score
+from rouge import Rouge
+from sentence_transformers import SentenceTransformer, util
+from scipy.spatial.distance import cosine
+from transformers import AutoModel, AutoTokenizer
+import torch
+
+
+import numpy as np
+from tqdm import tqdm
+
+class TraditionalMetricEvaluator():
+ def __init__(self, inputs, output_dir, output_file):
+ self.results = inputs['results']
+ self.inference_prompt = inputs['prompt']
+ self.output_dir = output_dir
+ self.output_file = output_file
+ self.rouge = Rouge()
+ self.response_data = []
+
+ self.ground_truths = []
+ self.generated_captions = []
+
+ self.sbert_model = SentenceTransformer('all-mpnet-base-v2')
+
+ self.simcse_tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+ self.simcse_model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+
+ self.scores = {
+ 'bleu-1': [],
+ 'bleu-2': [],
+ 'bleu-3': [],
+ 'bleu-4': [],
+ 'rouge-1': [],
+ 'rouge-2': [],
+ 'rouge-l': [],
+ 'meteor': [],
+ 'sbert_similarity': [],
+ 'simcse_similarity': []
+ }
+
+ def evaluate_result(self, result):
+ object_id = result['object_id']
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+
+ if model_output == "":
+ # * all score should be 0
+ model_output = "##"
+
+ # create a SmoothingFunction object
+ smoothing_function = SmoothingFunction().method1 # * used to deal with non-overlap n-gram
+
+ # calculate BLEU-1 score with smoothing function
+ bleu_1_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(1, 0, 0, 0), smoothing_function=smoothing_function)
+
+ # calculate BLEU-2, BLEU-3, and BLEU-4 scores
+ bleu_2_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing_function)
+ bleu_3_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing_function)
+ bleu_4_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing_function)
+
+ # calculate ROUGE-L score
+ rouge_scores_l = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-l']
+ rouge_scores_1 = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-1']
+ rouge_scores_2 = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-2']
+
+ # calculate METEOR score
+ meteor_scores = meteor_score([ground_truth.split()], model_output.split())
+
+ # Calculate SBERT similarity
+ embeddings = self.sbert_model.encode([ground_truth, model_output])
+ sbert_similarity = util.cos_sim(embeddings[0], embeddings[1])[0][0].item()
+
+ # calculate SimCSE similarity
+ # Tokenize input texts
+ inputs = self.simcse_tokenizer([ground_truth, model_output], padding=True, truncation=True, return_tensors="pt")
+
+ # Get the embeddings
+ with torch.no_grad():
+ embeddings = self.simcse_model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
+
+ # Calculate cosine similarity
+ simcse_similarity = 1 - cosine(embeddings[0], embeddings[1]) # * scipy's cosine computes cosine distance, which is 1 - cosine similarity
+
+ scores = {
+ 'bleu-1': bleu_1_score * 100,
+ 'bleu-2': bleu_2_score * 100,
+ 'bleu-3': bleu_3_score * 100,
+ 'bleu-4': bleu_4_score * 100,
+ 'rouge-l': rouge_scores_l['f'] * 100,
+ 'rouge-1': rouge_scores_1['f'] * 100,
+ 'rouge-2': rouge_scores_2['f'] * 100,
+ 'meteor': meteor_scores * 100,
+ 'sbert_similarity': sbert_similarity * 100,
+ 'simcse_similarity': simcse_similarity * 100
+ }
+
+ return object_id, model_output, ground_truth, scores
+
+ def evaluate(self):
+ print("Starting evaluation...")
+
+ for result in tqdm(self.results, desc="Evaluating"):
+ object_id, model_output, ground_truth, scores = self.evaluate_result(result)
+
+ # save the object_id, model_output, ground_truth, and scores for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'scores': scores,
+ })
+
+ # save the scores for overall results
+ for metric, score in scores.items():
+ self.scores[metric].append(score)
+
+ print("Evaluation finished.")
+ self.save_results()
+ self.print_results()
+
+ def save_results(self):
+ output_path = os.path.join(self.output_dir, self.output_file)
+
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'overall_scores': {metric: f"{np.mean(scores):.4f}" for metric, scores in self.scores.items()},
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+
+ def print_results(self):
+ print('-' * 80)
+ print("Results:")
+ for metric, scores in self.scores.items():
+ print(f"Average {metric.upper()} Score: {np.mean(scores):.4f}")
+
+def start_evaluation(results, output_dir, output_file,
+ parallel=True, num_workers=20):
+ """
+ Args:
+ results: dict or file path to the json file containing the dict
+ output_file: the path the final evaluation results to be saved.
+ """
+ if isinstance(results, str):
+ with open(results, 'r') as fp:
+ results = json.load(fp)
+
+ evaluator = TraditionalMetricEvaluator(results, output_dir, output_file)
+ evaluator.evaluate()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--results_path", type=str, \
+ default="", help="Path to the results file.")
+ parser.add_argument("--output_dir", type=str, default=None, help="Path to the output directory.")
+
+ args = parser.parse_args()
+
+ if args.output_dir is None:
+ args.output_dir = os.path.dirname(args.results_path)
+
+    output_file = os.path.basename(args.results_path).replace(".json", "_evaluated_traditional.json")
+
+ start_evaluation(results=args.results_path, output_dir=args.output_dir, output_file=output_file)
+
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/eval/utils.py b/ThirdParty/PointLLM/pointllm/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee145fac7ffd185128ae08f4447d7102ee72a62
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/eval/utils.py
@@ -0,0 +1,69 @@
+import openai
+import time
+import random
+import os
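+# NOTE: this module targets the pre-1.0 openai SDK; openai.error.* and openai.ChatCompletion were removed in openai>=1.0.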
+
+def retry_with_exponential_backoff(
+ func,
+ initial_delay: float = 1,
+ exponential_base: float = 2,
+ jitter: bool = True,
+ max_retries: int = 40,
+ max_delay: int = 30,
+ errors: tuple = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout),
+):
+ """Retry a function with exponential backoff."""
+ def wrapper(*args, **kwargs):
+ num_retries = 0
+ delay = initial_delay
+
+ while True:
+ try:
+ return func(*args, **kwargs)
+ except errors as e:
+ # * print the error info
+ num_retries += 1
+ if num_retries > max_retries:
+ print(f"[OPENAI] Encounter error: {e}.")
+ raise Exception(
+ f"[OPENAI] Maximum number of retries ({max_retries}) exceeded."
+ )
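+                # back off exponentially: scale the delay by exponential_base (plus up to 100% random jitter when enabled), then sleep, capped at max_delay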
+ delay *= exponential_base * (1 + jitter * random.random())
+ time.sleep(min(delay, max_delay))
+ except Exception as e:
+ raise e
+ return wrapper
+
+class OpenAIGPT():
+ def __init__(self, model="gpt-3.5-turbo-0613", temperature=1, top_p=1, max_tokens=2048, **kwargs) -> None:
+ setup_openai(model)
+ self.default_chat_parameters = {
+ "model": model,
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_tokens": max_tokens,
+ **kwargs
+ }
+
+ @retry_with_exponential_backoff
+ def safe_chat_complete(self, messages, content_only=True, **kwargs):
+ chat_parameters = self.default_chat_parameters.copy()
+ if len(kwargs) > 0:
+ chat_parameters.update(**kwargs)
+
+ response = openai.ChatCompletion.create(
+ messages=messages,
+ **chat_parameters
+ )
+
+ if content_only:
+ response = response['choices'][0]["message"]['content']
+
+ return response
+
+def setup_openai(model_name):
+ # Setup OpenAI API Key
+ print("[OPENAI] Setting OpenAI api_key...")
+ openai.api_key = os.getenv('OPENAI_API_KEY')
+ print(f"[OPENAI] OpenAI organization: {openai.organization}")
+ print(f"[OPENAI] Using MODEL: {model_name}")
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/__init__.py b/ThirdParty/PointLLM/pointllm/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ea5e0477ba727cf099c4fcfb89e9dcf59ec2be0
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/__init__.py
@@ -0,0 +1,2 @@
+# from .pointllm import PointLLMLlamaForCausalLM, PointLLMConfig
+from .pointbert.point_encoder import PointTransformer
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml b/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a90473b82e3afaa9654c2f7127c8e01d11006e4c
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml
@@ -0,0 +1,16 @@
+model : {
+ NAME: PointTransformer,
+ trans_dim: 384,
+ depth: 12,
+ drop_path_rate: 0.1,
+ cls_dim: 40,
+ num_heads: 6,
+ group_size: 32,
+ num_group: 512,
+ encoder_dims: 256,
+ point_dims: 3,
+ projection_hidden_layer: 2,
+ projection_hidden_dim: [1024, 2048],
+ use_max_pool: false
+}
+npoints: 8192
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_base_8192point.yaml b/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_base_8192point.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ac9db169433888af8cb9eed641f327eb8b00536d
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_base_8192point.yaml
@@ -0,0 +1,13 @@
+model : {
+ NAME: PointTransformer,
+ trans_dim: 1152, # * point feature dims (hidden state)
+ depth: 12,
+ drop_path_rate: 0.1,
+ cls_dim: 40,
+ num_heads: 12,
+ group_size: 48,
+ num_group: 512,
+ encoder_dims: 512, # * point group tokens feature
+ point_input_dims: 3,
+}
+npoints: 8192
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/checkpoint.py b/ThirdParty/PointLLM/pointllm/model/pointbert/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac680ab712235a4b8f4cc74f4c36b969ad6e57b
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/checkpoint.py
@@ -0,0 +1,126 @@
+from collections import defaultdict
+import torch.nn as nn
+
+from typing import Any
+from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
+
+from termcolor import colored
+
+def get_missing_parameters_message(keys: List[str]) -> str:
+ """
+ Get a logging-friendly message to report parameter names (keys) that are in
+ the model but not found in a checkpoint.
+ Args:
+ keys (list[str]): List of keys that were not found in the checkpoint.
+ Returns:
+ str: message.
+ """
+ groups = _group_checkpoint_keys(keys)
+ msg = "Some model parameters or buffers are not found in the checkpoint:\n"
+ msg += "\n".join(
+ " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
+ )
+ return msg
+
+
+def get_unexpected_parameters_message(keys: List[str]) -> str:
+ """
+ Get a logging-friendly message to report parameter names (keys) that are in
+ the checkpoint but not found in the model.
+ Args:
+ keys (list[str]): List of keys that were not found in the model.
+ Returns:
+ str: message.
+ """
+ groups = _group_checkpoint_keys(keys)
+ msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
+ msg += "\n".join(
+ " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
+ )
+ return msg
+
+
+def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
+ """
+ Strip the prefix in metadata, if any.
+ Args:
+ state_dict (OrderedDict): a state-dict to be loaded to the model.
+ prefix (str): prefix.
+ """
+ keys = sorted(state_dict.keys())
+ if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
+ return
+
+ for key in keys:
+ newkey = key[len(prefix):]
+ state_dict[newkey] = state_dict.pop(key)
+
+ # also strip the prefix in metadata, if any..
+ try:
+ metadata = state_dict._metadata # pyre-ignore
+ except AttributeError:
+ pass
+ else:
+ for key in list(metadata.keys()):
+ # for the metadata dict, the key can be:
+ # '': for the DDP module, which we want to remove.
+ # 'module': for the actual model.
+ # 'module.xx.xx': for the rest.
+
+ if len(key) == 0:
+ continue
+ newkey = key[len(prefix):]
+ metadata[newkey] = metadata.pop(key)
+
+
+def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
+ """
+ Group keys based on common prefixes. A prefix is the string up to the final
+ "." in each key.
+ Args:
+ keys (list[str]): list of parameter names, i.e. keys in the model
+ checkpoint dict.
+ Returns:
+ dict[list]: keys with common prefixes are grouped into lists.
+ """
+ groups = defaultdict(list)
+ for key in keys:
+ pos = key.rfind(".")
+ if pos >= 0:
+ head, tail = key[:pos], [key[pos + 1:]]
+ else:
+ head, tail = key, []
+ groups[head].extend(tail)
+ return groups
+
+
+def _group_to_str(group: List[str]) -> str:
+ """
+ Format a group of parameter name suffixes into a loggable string.
+ Args:
+ group (list[str]): list of parameter name suffixes.
+ Returns:
+        str: formatted string.
+ """
+ if len(group) == 0:
+ return ""
+
+ if len(group) == 1:
+ return "." + group[0]
+
+ return ".{" + ", ".join(group) + "}"
+
+
+def _named_modules_with_dup(
+ model: nn.Module, prefix: str = ""
+) -> Iterable[Tuple[str, nn.Module]]:
+ """
+ The same as `model.named_modules()`, except that it includes
+ duplicated modules that have more than one name.
+ """
+ yield prefix, model
+ for name, module in model._modules.items(): # pyre-ignore
+ if module is None:
+ continue
+ submodule_prefix = prefix + ("." if prefix else "") + name
+ yield from _named_modules_with_dup(module, submodule_prefix)
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/dvae.py b/ThirdParty/PointLLM/pointllm/model/pointbert/dvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..056c025bbc9b0ba0eab61ba1163c79745214fe81
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/dvae.py
@@ -0,0 +1,355 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from . import misc
+
+# from knn_cuda import KNN
+
+# knn = KNN(k=4, transpose_mode=False)
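+# NOTE: DGCNN.get_graph_feature calls knn(), which is only defined if the knn_cuda import above is re-enabled.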
+
+
+class DGCNN(nn.Module):
+ def __init__(self, encoder_channel, output_channel):
+        '''
+        K has to be 16
+        '''
+        super().__init__()
+ self.input_trans = nn.Conv1d(encoder_channel, 128, 1)
+
+ self.layer1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
+ nn.GroupNorm(4, 256),
+ nn.LeakyReLU(negative_slope=0.2)
+ )
+
+ self.layer2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, bias=False),
+ nn.GroupNorm(4, 512),
+ nn.LeakyReLU(negative_slope=0.2)
+ )
+
+ self.layer3 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=1, bias=False),
+ nn.GroupNorm(4, 512),
+ nn.LeakyReLU(negative_slope=0.2)
+ )
+
+ self.layer4 = nn.Sequential(nn.Conv2d(1024, 1024, kernel_size=1, bias=False),
+ nn.GroupNorm(4, 1024),
+ nn.LeakyReLU(negative_slope=0.2)
+ )
+
+ self.layer5 = nn.Sequential(nn.Conv1d(2304, output_channel, kernel_size=1, bias=False),
+ nn.GroupNorm(4, output_channel),
+ nn.LeakyReLU(negative_slope=0.2)
+ )
+
+ @staticmethod
+ def get_graph_feature(coor_q, x_q, coor_k, x_k):
+ # coor: bs, 3, np, x: bs, c, np
+
+ k = 4
+ batch_size = x_k.size(0)
+ num_points_k = x_k.size(2)
+ num_points_q = x_q.size(2)
+
+ with torch.no_grad():
+ _, idx = knn(coor_k, coor_q) # bs k np
+ assert idx.shape[1] == k
+ idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
+ idx = idx + idx_base
+ idx = idx.view(-1)
+ num_dims = x_k.size(1)
+ x_k = x_k.transpose(2, 1).contiguous()
+ feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
+ feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
+ x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
+ feature = torch.cat((feature - x_q, x_q), dim=1)
+ return feature
+
+ def forward(self, f, coor):
+ # f: B G C
+ # coor: B G 3
+
+ # bs 3 N bs C N
+ feature_list = []
+ coor = coor.transpose(1, 2).contiguous() # B 3 N
+ f = f.transpose(1, 2).contiguous() # B C N
+ f = self.input_trans(f) # B 128 N
+
+ f = self.get_graph_feature(coor, f, coor, f) # B 256 N k
+ f = self.layer1(f) # B 256 N k
+ f = f.max(dim=-1, keepdim=False)[0] # B 256 N
+ feature_list.append(f)
+
+ f = self.get_graph_feature(coor, f, coor, f) # B 512 N k
+ f = self.layer2(f) # B 512 N k
+ f = f.max(dim=-1, keepdim=False)[0] # B 512 N
+ feature_list.append(f)
+
+ f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k
+ f = self.layer3(f) # B 512 N k
+ f = f.max(dim=-1, keepdim=False)[0] # B 512 N
+ feature_list.append(f)
+
+ f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k
+ f = self.layer4(f) # B 1024 N k
+ f = f.max(dim=-1, keepdim=False)[0] # B 1024 N
+ feature_list.append(f)
+
+ f = torch.cat(feature_list, dim=1) # B 2304 N
+
+ f = self.layer5(f) # B C' N
+
+ f = f.transpose(-1, -2)
+
+ return f
+
+
+### ref https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py ###
+def knn_point(nsample, xyz, new_xyz):
+ """
+ Input:
+ nsample: max sample number in local region
+ xyz: all points, [B, N, C]
+ new_xyz: query points, [B, S, C]
+ Return:
+ group_idx: grouped points index, [B, S, nsample]
+ """
+ sqrdists = square_distance(new_xyz, xyz)
+ _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
+ return group_idx
+
+
+def square_distance(src, dst):
+ """
+ Calculate Euclid distance between each two points.
+ src^T * dst = xn * xm + yn * ym + zn * zm;
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+ Input:
+ src: source points, [B, N, C]
+ dst: target points, [B, M, C]
+ Output:
+ dist: per-point square distance, [B, N, M]
+ """
+ B, N, _ = src.shape
+ _, M, _ = dst.shape
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+ return dist
+
+
+class Group(nn.Module):
+ def __init__(self, num_group, group_size):
+ super().__init__()
+ self.num_group = num_group
+ self.group_size = group_size
+ # self.knn = KNN(k=self.group_size, transpose_mode=True)
+
+ def forward(self, xyz):
+ '''
+ input: B N 3
+ ---------------------------
+ output: B G M 3
+ center : B G 3
+ '''
+ B, N, C = xyz.shape
+ if C > 3:
+ data = xyz
+ xyz = data[:,:,:3]
+ rgb = data[:, :, 3:]
+ batch_size, num_points, _ = xyz.shape
+ # fps the centers out
+ center = misc.fps(xyz, self.num_group) # B G 3
+
+ # knn to get the neighborhood
+ # _, idx = self.knn(xyz, center) # B G M
+ idx = knn_point(self.group_size, xyz, center) # B G M
+ assert idx.size(1) == self.num_group
+ assert idx.size(2) == self.group_size
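+        # offset each batch's neighbor indices so they index into the flattened (B*N, C) view of the points below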
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
+ idx = idx + idx_base
+ idx = idx.view(-1)
+
+ neighborhood_xyz = xyz.view(batch_size * num_points, -1)[idx, :]
+ neighborhood_xyz = neighborhood_xyz.view(batch_size, self.num_group, self.group_size, 3).contiguous()
+ if C > 3:
+ neighborhood_rgb = rgb.view(batch_size * num_points, -1)[idx, :]
+ neighborhood_rgb = neighborhood_rgb.view(batch_size, self.num_group, self.group_size, -1).contiguous()
+
+ # normalize xyz
+ neighborhood_xyz = neighborhood_xyz - center.unsqueeze(2)
+ if C > 3:
+ neighborhood = torch.cat((neighborhood_xyz, neighborhood_rgb), dim=-1)
+ else:
+ neighborhood = neighborhood_xyz
+ return neighborhood, center
+
+class Encoder(nn.Module):
+ def __init__(self, encoder_channel, point_input_dims=3):
+ super().__init__()
+ self.encoder_channel = encoder_channel
+ self.point_input_dims = point_input_dims
+ self.first_conv = nn.Sequential(
+ nn.Conv1d(self.point_input_dims, 128, 1),
+ nn.BatchNorm1d(128),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(128, 256, 1)
+ )
+ self.second_conv = nn.Sequential(
+ nn.Conv1d(512, 512, 1),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(512, self.encoder_channel, 1)
+ )
+
+ def forward(self, point_groups):
+ '''
+ point_groups : B G N 3
+ -----------------
+ feature_global : B G C
+ '''
+ bs, g, n, c = point_groups.shape
+ point_groups = point_groups.reshape(bs * g, n, c)
+ # encoder
+ feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
+ feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
+ feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
+ feature = self.second_conv(feature) # BG 1024 n
+ feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
+ return feature_global.reshape(bs, g, self.encoder_channel)
+
+
+class Decoder(nn.Module):
+ def __init__(self, encoder_channel, num_fine):
+ super().__init__()
+ self.num_fine = num_fine
+ self.grid_size = 2
+ self.num_coarse = self.num_fine // 4
+ assert num_fine % 4 == 0
+
+ self.mlp = nn.Sequential(
+ nn.Linear(encoder_channel, 1024),
+ nn.ReLU(inplace=True),
+ nn.Linear(1024, 1024),
+ nn.ReLU(inplace=True),
+ nn.Linear(1024, 3 * self.num_coarse)
+ )
+ self.final_conv = nn.Sequential(
+ nn.Conv1d(encoder_channel + 3 + 2, 512, 1),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(512, 512, 1),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(512, 3, 1)
+ )
+ a = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(1, self.grid_size).expand(
+ self.grid_size, self.grid_size).reshape(1, -1)
+ b = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(self.grid_size, 1).expand(
+ self.grid_size, self.grid_size).reshape(1, -1)
+ self.folding_seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2) # 1 2 S
+
+ def forward(self, feature_global):
+ '''
+ feature_global : B G C
+ -------
+ coarse : B G M 3
+ fine : B G N 3
+
+ '''
+ bs, g, c = feature_global.shape
+ feature_global = feature_global.reshape(bs * g, c)
+
+ coarse = self.mlp(feature_global).reshape(bs * g, self.num_coarse, 3) # BG M 3
+
+ point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3
+ point_feat = point_feat.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N
+
+ seed = self.folding_seed.unsqueeze(2).expand(bs * g, -1, self.num_coarse, -1) # BG 2 M (S)
+ seed = seed.reshape(bs * g, -1, self.num_fine).to(feature_global.device) # BG 2 N
+
+ feature_global = feature_global.unsqueeze(2).expand(-1, -1, self.num_fine) # BG 1024 N
+ feat = torch.cat([feature_global, seed, point_feat], dim=1) # BG C N
+
+ center = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3
+ center = center.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N
+
+ fine = self.final_conv(feat) + center # BG 3 N
+ fine = fine.reshape(bs, g, 3, self.num_fine).transpose(-1, -2)
+ coarse = coarse.reshape(bs, g, self.num_coarse, 3)
+ return coarse, fine
+
+
+class DiscreteVAE(nn.Module):
+ def __init__(self, config, **kwargs):
+ super().__init__()
+ self.group_size = config.group_size
+ self.num_group = config.num_group
+ self.encoder_dims = config.encoder_dims
+ self.tokens_dims = config.tokens_dims
+
+ self.decoder_dims = config.decoder_dims
+ self.num_tokens = config.num_tokens
+
+ self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
+ self.encoder = Encoder(encoder_channel=self.encoder_dims)
+ self.dgcnn_1 = DGCNN(encoder_channel=self.encoder_dims, output_channel=self.num_tokens)
+ self.codebook = nn.Parameter(torch.randn(self.num_tokens, self.tokens_dims))
+
+ self.dgcnn_2 = DGCNN(encoder_channel=self.tokens_dims, output_channel=self.decoder_dims)
+ self.decoder = Decoder(encoder_channel=self.decoder_dims, num_fine=self.group_size)
+ # self.build_loss_func()
+
+ # def build_loss_func(self):
+ # self.loss_func_cdl1 = ChamferDistanceL1().cuda()
+ # self.loss_func_cdl2 = ChamferDistanceL2().cuda()
+ # self.loss_func_emd = emd().cuda()
+
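+    # NOTE: recon_loss/get_loss reference self.loss_func_cdl1, which is only created by the commented-out
+    # build_loss_func above; restore those Chamfer-distance losses before training the DiscreteVAE.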
+ def recon_loss(self, ret, gt):
+ whole_coarse, whole_fine, coarse, fine, group_gt, _ = ret
+
+ bs, g, _, _ = coarse.shape
+
+ coarse = coarse.reshape(bs * g, -1, 3).contiguous()
+ fine = fine.reshape(bs * g, -1, 3).contiguous()
+ group_gt = group_gt.reshape(bs * g, -1, 3).contiguous()
+
+ loss_coarse_block = self.loss_func_cdl1(coarse, group_gt)
+ loss_fine_block = self.loss_func_cdl1(fine, group_gt)
+
+ loss_recon = loss_coarse_block + loss_fine_block
+
+ return loss_recon
+
+ def get_loss(self, ret, gt):
+ # reconstruction loss
+ loss_recon = self.recon_loss(ret, gt)
+ # kl divergence
+ logits = ret[-1] # B G N
+ softmax = F.softmax(logits, dim=-1)
+ mean_softmax = softmax.mean(dim=1)
+ log_qy = torch.log(mean_softmax)
+ log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=gt.device))
+ loss_klv = F.kl_div(log_qy, log_uniform.expand(log_qy.size(0), log_qy.size(1)), None, None, 'batchmean',
+ log_target=True)
+
+ return loss_recon, loss_klv
+
+ def forward(self, inp, temperature=1., hard=False, **kwargs):
+ neighborhood, center = self.group_divider(inp)
+ logits = self.encoder(neighborhood) # B G C
+ logits = self.dgcnn_1(logits, center) # B G N
+ soft_one_hot = F.gumbel_softmax(logits, tau=temperature, dim=2, hard=hard) # B G N
+ sampled = torch.einsum('b g n, n c -> b g c', soft_one_hot, self.codebook) # B G C
+ feature = self.dgcnn_2(sampled, center)
+ coarse, fine = self.decoder(feature)
+
+ with torch.no_grad():
+ whole_fine = (fine + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
+ whole_coarse = (coarse + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
+
+ assert fine.size(2) == self.group_size
+ ret = (whole_coarse, whole_fine, coarse, fine, neighborhood, logits)
+ return ret
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/logger.py b/ThirdParty/PointLLM/pointllm/model/pointbert/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..847c1c7a2f50f310cd5daf96b928838c1c293525
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/logger.py
@@ -0,0 +1,127 @@
+import logging
+import torch.distributed as dist
+
+logger_initialized = {}
+
+def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
+ """Get root logger and add a keyword filter to it.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+    also be added.
+    Args:
+        log_file (str, optional): File path of log. Defaults to None.
+        log_level (int, optional): The level of logger.
+            Defaults to logging.INFO.
+        name (str, optional): The name of the root logger, also used as a
+            filter keyword. Defaults to 'main'.
+ Returns:
+ :obj:`logging.Logger`: The obtained logger
+ """
+ logger = get_logger(name=name, log_file=log_file, log_level=log_level)
+ # add a logging filter
+ logging_filter = logging.Filter(name)
+ logging_filter.filter = lambda record: record.find(name) != -1
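+    # (the filter above is constructed but never attached via logger.addFilter, so it currently has no effect)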
+
+ return logger
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+ # Here, the default behaviour of the official logger is 'a'. Thus, we
+ # provide an interface to change the file mode to the default
+ # behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+
+ return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+ - other str: the logger obtained with `get_root_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == 'silent':
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ 'logger should be either a logging.Logger object, str, '
+ f'"silent" or None, but got {type(logger)}')
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/misc.py b/ThirdParty/PointLLM/pointllm/model/pointbert/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..02071cb2e4f70b143c86c617f16d5922a88f24f6
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/misc.py
@@ -0,0 +1,287 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+from collections import abc
+# from pointnet2_ops import pointnet2_utils
+
+
+# def fps(data, number):
+# '''
+# data B N 3
+# number int
+# '''
+# fps_idx = pointnet2_utils.furthest_point_sample(data, number)
+# fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
+# return fps_data
+
+def index_points(points, idx):
+ """
+ Input:
+ points: input points data, [B, N, C]
+ idx: sample index data, [B, S]
+ Return:
+ new_points:, indexed points data, [B, S, C]
+ """
+ device = points.device
+ B = points.shape[0]
+ view_shape = list(idx.shape)
+ view_shape[1:] = [1] * (len(view_shape) - 1)
+ repeat_shape = list(idx.shape)
+ repeat_shape[0] = 1
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+ new_points = points[batch_indices, idx, :]
+ return new_points
+
+def fps(xyz, npoint):
+ """
+ Input:
+ xyz: pointcloud data, [B, N, 3]
+ npoint: number of samples
+ Return:
+ centroids: sampled pointcloud index, [B, npoint]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
+ distance = torch.ones(B, N).to(device) * 1e10
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
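+    # greedy farthest point sampling: each iteration selects the point with the largest distance to the set chosen so far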
+ for i in range(npoint):
+ centroids[:, i] = farthest
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
+ dist = torch.sum((xyz - centroid) ** 2, -1)
+ distance = torch.min(distance, dist)
+ farthest = torch.max(distance, -1)[1]
+ return index_points(xyz, centroids)
+
+def worker_init_fn(worker_id):
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+def build_lambda_sche(opti, config):
+ if config.get('decay_step') is not None:
+ lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
+ scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
+ else:
+ raise NotImplementedError()
+ return scheduler
+
+def build_lambda_bnsche(model, config):
+ if config.get('decay_step') is not None:
+ bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
+ bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
+ else:
+ raise NotImplementedError()
+ return bnm_scheduler
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+
+    Note on the speed-reproducibility tradeoff
+    (https://pytorch.org/docs/stable/notes/randomness.html):
+    deterministic=True sets cudnn.deterministic=True and cudnn.benchmark=False,
+    which is slower but more reproducible; with deterministic=False the cuDNN
+    settings are left unchanged.
+    """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def set_bn_momentum_default(bn_momentum):
+ def fn(m):
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+ m.momentum = bn_momentum
+ return fn
+
+class BNMomentumScheduler(object):
+
+ def __init__(
+ self, model, bn_lambda, last_epoch=-1,
+ setter=set_bn_momentum_default
+ ):
+ if not isinstance(model, nn.Module):
+ raise RuntimeError(
+ "Class '{}' is not a PyTorch nn Module".format(
+ type(model).__name__
+ )
+ )
+
+ self.model = model
+ self.setter = setter
+ self.lmbd = bn_lambda
+
+ self.step(last_epoch + 1)
+ self.last_epoch = last_epoch
+
+ def step(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+
+ self.last_epoch = epoch
+ self.model.apply(self.setter(self.lmbd(epoch)))
+
+ def get_momentum(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ return self.lmbd(epoch)
+
+
+
+def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
+ '''
+     Separate a point cloud: used to generate an incomplete point cloud by cropping a set number of points around a (random or fixed) center.
+ '''
+ _,n,c = xyz.shape
+
+ assert n == num_points
+ assert c == 3
+ if crop == num_points:
+ return xyz, None
+
+ INPUT = []
+ CROP = []
+ for points in xyz:
+ if isinstance(crop,list):
+ num_crop = random.randint(crop[0],crop[1])
+ else:
+ num_crop = crop
+
+ points = points.unsqueeze(0)
+
+ if fixed_points is None:
+ center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
+ else:
+ if isinstance(fixed_points,list):
+ fixed_point = random.sample(fixed_points,1)[0]
+ else:
+ fixed_point = fixed_points
+ center = fixed_point.reshape(1,1,3).cuda()
+
+ distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
+
+ idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
+
+ if padding_zeros:
+ input_data = points.clone()
+ input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
+
+ else:
+ input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
+
+ crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
+
+ if isinstance(crop,list):
+ INPUT.append(fps(input_data,2048))
+ CROP.append(fps(crop_data,2048))
+ else:
+ INPUT.append(input_data)
+ CROP.append(crop_data)
+
+ input_data = torch.cat(INPUT,dim=0)# B N 3
+ crop_data = torch.cat(CROP,dim=0)# B M 3
+
+ return input_data.contiguous(), crop_data.contiguous()
+
+def get_ptcloud_img(ptcloud):
+ fig = plt.figure(figsize=(8, 8))
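+    # NOTE: this helper uses older Matplotlib/NumPy APIs (fig.gca(projection=...), canvas.tostring_rgb, np.fromstring)
+    # that newer releases deprecate or remove; fig.add_subplot(projection='3d') and np.frombuffer(canvas.buffer_rgba(), ...) are the modern equivalents.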
+
+ x, z, y = ptcloud.transpose(1, 0)
+ ax = fig.gca(projection=Axes3D.name, adjustable='box')
+ ax.axis('off')
+ # ax.axis('scaled')
+ ax.view_init(30, 45)
+    pc_max, pc_min = np.max(ptcloud), np.min(ptcloud)
+    ax.set_xbound(pc_min, pc_max)
+    ax.set_ybound(pc_min, pc_max)
+    ax.set_zbound(pc_min, pc_max)
+ ax.scatter(x, y, z, zdir='z', c=x, cmap='jet')
+
+ fig.canvas.draw()
+ img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+ return img
+
+
+
+def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
+ xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
+ fig = plt.figure(figsize=(6*len(data_list),6))
+ cmax = data_list[-1][:,0].max()
+
+ for i in range(len(data_list)):
+ data = data_list[i][:-2048] if i == 1 else data_list[i]
+ color = data[:,0] /cmax
+ ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
+ ax.view_init(30, -120)
+ b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
+ ax.set_title(titles[i])
+
+ ax.set_axis_off()
+ ax.set_xlim(xlim)
+ ax.set_ylim(ylim)
+ ax.set_zlim(zlim)
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ pic_path = path + '.png'
+ fig.savefig(pic_path)
+
+ np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
+ np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
+ plt.close(fig)
+
+
+def random_dropping(pc, e):
+ up_num = max(64, 768 // (e//50 + 1))
+ pc = pc
+ random_num = torch.randint(1, up_num, (1,1))[0,0]
+ pc = fps(pc, random_num)
+ padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
+ pc = torch.cat([pc, padding], dim = 1)
+ return pc
+
+
+def random_scale(partial, scale_range=[0.8, 1.2]):
+ scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
+ return partial * scale
diff --git a/ThirdParty/PointLLM/pointllm/model/pointbert/point_encoder.py b/ThirdParty/PointLLM/pointllm/model/pointbert/point_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e01a0186bdb6d18bc64f0c9838043854d635c645
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointbert/point_encoder.py
@@ -0,0 +1,189 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath
+from .dvae import Group
+from .dvae import Encoder
+from .logger import print_log
+from collections import OrderedDict
+
+from .checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ """ Transformer Encoder without hierarchical structure
+ """
+
+ def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
+ super().__init__()
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
+ )
+ for i in range(depth)])
+
+ def forward(self, x, pos):
+ for _, block in enumerate(self.blocks):
+ x = block(x + pos)
+ return x
+
+
+class PointTransformer(nn.Module):
+ def __init__(self, config, use_max_pool=True):
+ super().__init__()
+ self.config = config
+
+        self.use_max_pool = use_max_pool # * whether to max pool the features of different tokens
+
+ self.trans_dim = config.trans_dim
+ self.depth = config.depth
+ self.drop_path_rate = config.drop_path_rate
+ self.cls_dim = config.cls_dim
+ self.num_heads = config.num_heads
+
+ self.group_size = config.group_size
+ self.num_group = config.num_group
+ self.point_dims = config.point_dims
+ # grouper
+ self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
+ # define the encoder
+ self.encoder_dims = config.encoder_dims
+ self.encoder = Encoder(encoder_channel=self.encoder_dims, point_input_dims=self.point_dims)
+ # bridge encoder and transformer
+ self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
+ self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
+
+ self.pos_embed = nn.Sequential(
+ nn.Linear(3, 128),
+ nn.GELU(),
+ nn.Linear(128, self.trans_dim)
+ )
+
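+        # stochastic depth: the drop-path rate increases linearly from 0 to drop_path_rate across the transformer blocks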
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
+ self.blocks = TransformerEncoder(
+ embed_dim=self.trans_dim,
+ depth=self.depth,
+ drop_path_rate=dpr,
+ num_heads=self.num_heads
+ )
+
+ self.norm = nn.LayerNorm(self.trans_dim)
+
+ def load_checkpoint(self, bert_ckpt_path):
+ ckpt = torch.load(bert_ckpt_path, map_location='cpu')
+ state_dict = OrderedDict()
+ for k, v in ckpt['state_dict'].items():
+ if k.startswith('module.point_encoder.'):
+ state_dict[k.replace('module.point_encoder.', '')] = v
+
+ incompatible = self.load_state_dict(state_dict, strict=False)
+
+ if incompatible.missing_keys:
+ print_log('missing_keys', logger='Transformer')
+ print_log(
+ get_missing_parameters_message(incompatible.missing_keys),
+ logger='Transformer'
+ )
+ if incompatible.unexpected_keys:
+ print_log('unexpected_keys', logger='Transformer')
+ print_log(
+ get_unexpected_parameters_message(incompatible.unexpected_keys),
+ logger='Transformer'
+ )
+ if not incompatible.missing_keys and not incompatible.unexpected_keys:
+ # * print successful loading
+ print_log("PointBERT's weights are successfully loaded from {}".format(bert_ckpt_path), logger='Transformer')
+
+ def forward(self, pts):
+        # divide the point cloud into local groups in the same form as used for pre-training. This is important
+        neighborhood, center = self.group_divider(pts)
+        # encode the input point groups
+ group_input_tokens = self.encoder(neighborhood) # B G N
+ group_input_tokens = self.reduce_dim(group_input_tokens)
+ # prepare cls
+ cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
+ cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
+ # add pos embedding
+ pos = self.pos_embed(center)
+ # final input
+ x = torch.cat((cls_tokens, group_input_tokens), dim=1)
+ pos = torch.cat((cls_pos, pos), dim=1)
+ # transformer
+ x = self.blocks(x, pos)
+ x = self.norm(x) # * B, G + 1(cls token)(513), C(384)
+ if not self.use_max_pool:
+ return x
+ concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1).unsqueeze(1) # * concat the cls token and max pool the features of different tokens, make it B, 1, C
+ return concat_f # * B, 1, C(384 + 384)
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/pointllm/model/pointllm.py b/ThirdParty/PointLLM/pointllm/model/pointllm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ef8111218fe7e17c3bee2a958063afc57ac080a
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/pointllm.py
@@ -0,0 +1,353 @@
+# Copyright 2023 Runsen Xu
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+from .utils import *
+from ThirdParty.PointLLM.pointllm.utils import *
+
+from contextlib import nullcontext
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+import os
+
+# * add logger
+import logging
+logger = logging.getLogger(__name__)
+
+class PointLLMConfig(LlamaConfig):
+ model_type = "pointllm"
+
+class PointLLMLlamaModel(LlamaModel):
+ config_class = PointLLMConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(PointLLMLlamaModel, self).__init__(config)
+
+ self.point_backbone_type = config.point_backbone
+ logger.info(f"Using {self.point_backbone_type}.")
+
+ if self.point_backbone_type == "PointBERT":
+ from pointllm.model import PointTransformer
+ # address of config file, in the same dir of this file
+ point_bert_config_name = getattr(config, "point_backbone_config_name", "PointTransformer_8192point_2layer") # * default for v1.2, v1.1 uses PointTransformer_base_8192point.yaml
+ point_bert_config_addr = os.path.join(os.path.dirname(__file__), "pointbert", f"{point_bert_config_name}.yaml")
+ print(f"Loading PointBERT config from {point_bert_config_addr}.")
+ point_bert_config = cfg_from_yaml_file(point_bert_config_addr)
+ if getattr(config, "use_color", False):
+ point_bert_config.model.point_dims = 6
+ use_max_pool = getattr(point_bert_config.model, "use_max_pool", False) # * default is false
+
+ self.point_backbone = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool)
+ logger.info(f"Using {self.point_backbone.point_dims} dim of points.")
+
+ self.point_backbone_config = {
+ "point_cloud_dim": point_bert_config.model.point_dims,
+ "backbone_output_dim": point_bert_config.model.trans_dim if not use_max_pool else point_bert_config.model.trans_dim * 2,
+ "project_output_dim": self.config.hidden_size,
+ "point_token_len": point_bert_config.model.num_group + 1 if not use_max_pool else 1, # * number of output features, with cls token
+ "mm_use_point_start_end": self.config.mm_use_point_start_end,
+ "projection_hidden_layer": point_bert_config.model.get('projection_hidden_layer', 0),
+ "use_max_pool": use_max_pool
+ }
+ if point_bert_config.model.get('projection_hidden_layer', 0) > 0:
+ self.point_backbone_config["projection_hidden_dim"] = point_bert_config.model.projection_hidden_dim # a list
+
+ logger.info(f"Use max pool is {use_max_pool}. Number of point token is {self.point_backbone_config['point_token_len']}.")
+
+ # * print relevant info with projection layers
+ backbone_output_dim = self.point_backbone_config["backbone_output_dim"]
+ logger.info(f"Point backbone output dim: {backbone_output_dim}.")
+ logger.info(f"Use {self.point_backbone_config['projection_hidden_layer']} projection hiddent layers.")
+ if self.point_backbone_config['projection_hidden_layer'] > 0:
+ # Add projection layer with linear layers and GELU activation
+ projection_layers = []
+ last_dim = backbone_output_dim
+ for i in range(point_bert_config.model.projection_hidden_layer):
+ projection_layers.append(nn.Linear(last_dim, self.point_backbone_config["projection_hidden_dim"][i]))
+ projection_layers.append(nn.GELU())
+ last_dim = self.point_backbone_config["projection_hidden_dim"][i]
+
+ projection_layers.append(nn.Linear(last_dim, self.point_backbone_config["project_output_dim"]))
+ self.point_proj = nn.Sequential(*projection_layers)
+ logger.info(f"Each layer with {point_bert_config.model.projection_hidden_dim} hidden units.")
+ else:
+ # Single layer
+ self.point_proj = nn.Linear(backbone_output_dim, self.point_backbone_config['project_output_dim'])
+ logger.info(f"Point projector output dim: {self.point_backbone_config['project_output_dim']}.")
+
+ self.fix_pointnet = False
+ self.fix_llm = False
+
+ def load_point_backbone_checkpoint(self, checkpoint_path=None):
+ self.point_backbone.load_checkpoint(self.config.point_backbone_ckpt if checkpoint_path is None else checkpoint_path)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ point_clouds: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # HACK: replace back original embeddings for pretraining
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ point_backbone = getattr(self, 'point_backbone', None)
+ point_backbone_config = getattr(self, 'point_backbone_config', None)
+
+ if point_backbone is not None and (input_ids.shape[1] != 1 or self.training) and point_clouds is not None:
+ # * enter when training or the first generation step of inference
+ with torch.no_grad() if self.fix_pointnet else nullcontext():
+ if self.fix_pointnet:
+ self.point_backbone.eval()
+ if type(point_clouds) is list:
+ # * variable numbers of points
+ point_features = []
+ for point_cloud in point_clouds: # * iterate over batch
+ point_feature = self.point_backbone(point_cloud.unsqueeze(0))[0]
+ point_features.append(point_feature)
+ else:
+ point_features = self.point_backbone(point_clouds)
+
+ if type(point_clouds) is list:
+ point_features = [self.point_proj(point_feature) for point_feature in point_features]
+ else:
+ point_features = self.point_proj(point_features)
+
+ dummy_point_features = torch.zeros(point_backbone_config['point_token_len'], point_backbone_config['backbone_output_dim'], device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+ dummy_point_features = self.point_proj(dummy_point_features)
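+            # dummy features keep point_proj in the computation graph for samples without point tokens (via the 0. * (...) trick below)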
+
+ new_input_embeds = []
+ cur_point_idx = 0
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): # * input_ids: B, L; input_embeds: B, L, C
+ if (cur_input_ids == point_backbone_config['point_patch_token']).sum() == 0:
+ # multimodal LLM, but the current sample is not multimodal
+ cur_input_embeds = cur_input_embeds + (0. * dummy_point_features).sum() # * do nothing
+ new_input_embeds.append(cur_input_embeds)
+ cur_point_idx += 1
+ continue
+ cur_point_features = point_features[cur_point_idx].to(device=cur_input_embeds.device)
+ num_patches = cur_point_features.shape[0] # * number of point tokens
+ if point_backbone_config['mm_use_point_start_end']:
+ if (cur_input_ids == point_backbone_config["point_start_token"]).sum() != (cur_input_ids == point_backbone_config["point_end_token"]).sum():
+ raise ValueError("The number of point start tokens and point end tokens should be the same.")
+ point_start_tokens = torch.where(cur_input_ids == point_backbone_config["point_start_token"])[0]
+ for point_start_token_pos in point_start_tokens:
+ if cur_input_ids[point_start_token_pos + num_patches + 1] != point_backbone_config["point_end_token"]:
+ raise ValueError("The point end token should follow the point start token.")
+ if orig_embeds_params is not None: # * will not update the original embeddings except for POINT_START_TOKEN and POINT_END_TOKEN
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:point_start_token_pos].detach(), cur_input_embeds[point_start_token_pos:point_start_token_pos+1], cur_point_features, cur_input_embeds[point_start_token_pos + num_patches + 1:point_start_token_pos + num_patches + 2], cur_input_embeds[point_start_token_pos + num_patches + 2:].detach()), dim=0)
+ else:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:point_start_token_pos+1], cur_point_features, cur_input_embeds[point_start_token_pos + num_patches + 1:]), dim=0)
+ cur_point_idx += 1
+ new_input_embeds.append(cur_new_input_embeds)
+ else:
+ if (cur_input_ids == point_backbone_config["point_patch_token"]).sum() != num_patches:
+ raise ValueError("The number of point patch tokens should be the same as the number of point patches.")
+ masked_indices = torch.where(cur_input_ids == point_backbone_config["point_patch_token"])[0]
+ mask_index_start = masked_indices[0]
+ if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+ raise ValueError("The point patch tokens should be consecutive.")
+ if orig_embeds_params is not None:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_point_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
+ else:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_point_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
+ new_input_embeds.append(cur_new_input_embeds)
+ cur_point_idx += 1
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
+
+ return super(PointLLMLlamaModel, self).forward(
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+
+class PointLLMLlamaForCausalLM(LlamaForCausalLM):
+ config_class = PointLLMConfig
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = PointLLMLlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None, # * control whether to return past_key_values
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ point_clouds: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ point_clouds=point_clouds
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous() # * B, L, V(32003)
+ shift_labels = labels[..., 1:].contiguous() # * B, L
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "point_clouds": kwargs.get("point_clouds", None),
+ }
+ )
+ return model_inputs
+
+ def initialize_tokenizer_point_backbone_config_wo_embedding(self, tokenizer):
+ # * called when stage2 or inference or inference without pre-training, assume tokenizer has point tokens
+ config = self.config
+ point_backbone_config = self.get_model().point_backbone_config
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end'] = config.mm_use_point_start_end
+
+ default_point_patch_token = config.DEFAULT_POINT_PATCH_TOKEN
+
+ tokenizer.add_tokens([default_point_patch_token], special_tokens=True)
+
+ # * assert tokenizer has the default_point_patch_token
+ point_backbone_config['default_point_patch_token'] = default_point_patch_token
+ point_backbone_config['point_patch_token'] = tokenizer.convert_tokens_to_ids([default_point_patch_token])[0]
+
+ if mm_use_point_start_end:
+ default_point_start_token = config.DEFAULT_POINT_START_TOKEN
+ default_point_end_token = config.DEFAULT_POINT_END_TOKEN
+ tokenizer.add_tokens([default_point_start_token, default_point_end_token], special_tokens=True)
+
+ point_backbone_config['default_point_start_token'] = default_point_start_token
+ point_backbone_config['default_point_end_token'] = default_point_end_token
+
+ point_backbone_config["point_start_token"] = tokenizer.convert_tokens_to_ids([default_point_start_token])[0]
+ point_backbone_config["point_end_token"] = tokenizer.convert_tokens_to_ids([default_point_end_token])[0]
+
+ def initialize_tokenizer_point_backbone_config(self, tokenizer, device, fix_llm=True):
+
+ config = self.config
+ point_backbone_config = self.get_model().point_backbone_config
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end'] = config.mm_use_point_start_end
+
+ default_point_patch_token = config.DEFAULT_POINT_PATCH_TOKEN
+ point_backbone_config['default_point_patch_token'] = default_point_patch_token
+ tokenizer.add_tokens([default_point_patch_token], special_tokens=True) # * no need to update embed since it will be replaced
+ self.resize_token_embeddings(len(tokenizer)) # ! resize_token_embeddings will make the tokens trainable again
+ point_backbone_config['point_patch_token'] = tokenizer.convert_tokens_to_ids([default_point_patch_token])[0]
+
+ if mm_use_point_start_end:
+ default_point_start_token = config.DEFAULT_POINT_START_TOKEN
+ default_point_end_token = config.DEFAULT_POINT_END_TOKEN
+ point_backbone_config['default_point_start_token'] = default_point_start_token
+ point_backbone_config['default_point_end_token'] = default_point_end_token
+
+ num_new_tokens = tokenizer.add_tokens([default_point_start_token, default_point_end_token], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+ point_backbone_config["point_start_token"] = tokenizer.convert_tokens_to_ids([default_point_start_token])[0]
+ point_backbone_config["point_end_token"] = tokenizer.convert_tokens_to_ids([default_point_end_token])[0]
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
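+ # * initialize the rows for the newly added tokens to the mean of the existing embeddings,
+ # * so the new point tokens start from a reasonable location in embedding space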
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ # need to update the input embedding, but no need to update the output embedding
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ if fix_llm:
+ self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] # * only tuning the new embeddings
+ for p in self.get_output_embeddings().parameters(): # * the llm head
+ p.requires_grad = False
+ print(f"Setting output embeddings fixed and {num_new_tokens} new tokens' input embeddings trainable.")
+ else:
+ self.get_model().orig_embeds_params = None
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = True
+ print("Setting output embeddings and all input embeddings trainable.")
+
+AutoConfig.register("pointllm", PointLLMConfig)
+AutoModelForCausalLM.register(PointLLMConfig, PointLLMLlamaForCausalLM)
diff --git a/ThirdParty/PointLLM/pointllm/model/utils.py b/ThirdParty/PointLLM/pointllm/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b78741ca050c66d3c3891a236715f30652130c97
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/model/utils.py
@@ -0,0 +1,24 @@
+import torch
+from transformers import StoppingCriteria
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
+ self.tokenizer = tokenizer
+ self.start_len = None
+ self.input_ids = input_ids
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ if self.start_len is None:
+ self.start_len = self.input_ids.shape[1]
+ else:
+ for keyword_id in self.keyword_ids:
+ if output_ids[0, -1] == keyword_id:
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
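+
+# Usage sketch (names are illustrative; `model`, `tokenizer`, and `input_ids` are assumed to exist):
+#   from transformers import StoppingCriteriaList
+#   stop = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)
+#   model.generate(input_ids, stopping_criteria=StoppingCriteriaList([stop]))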
diff --git a/ThirdParty/PointLLM/pointllm/train/llama_flash_attn_monkey_patch.py b/ThirdParty/PointLLM/pointllm/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd3ba7f9361649b5ba0e5a9db312e002c1cac44
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,107 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+
+# * newer flash-attn versions renamed this to flash_attn_varlen_qkvpacked_func, so fall back if the old name is missing
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel
+
+ attention_mask: [bsz, q_len]
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ # [bsz, q_len, nh, hd]
+ # [bsz, nh, q_len, hd]
+
+ kv_seq_len = key_states.shape[-2]
+ offset = 0
+ if past_key_value is not None:
+ offset = past_key_value[0].shape[-2]
+ kv_seq_len += offset
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states,
+ key_states,
+ cos,
+ sin,
+ offset=offset)
+ # [bsz, nh, t, hd]
+ assert not output_attentions, "output_attentions is not supported"
+ assert not use_cache, "use_cache is not supported"
+ assert past_key_value is None, "past_key_value is not supported"
+
+ # Flash attention codes from
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+ # transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask = attention_mask
+
+
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = q_len
+ cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0,
+ softmax_scale=None, causal=True
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad, cu_q_lens, max_s, 0.0,
+ softmax_scale=None, causal=True
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, bsz, q_len),
+ 'b s (h d) -> b s h d', h=nheads)
+ return self.o_proj(rearrange(output,
+ 'b s h d -> b s (h d)')), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
+ inputs_embeds, past_key_values_length):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
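+
+# Usage sketch: call replace_llama_attn_with_flash_attn() before the LLaMA model is built
+# (as done in pointllm/train/train_mem.py) so the patched forward above takes effect.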
diff --git a/ThirdParty/PointLLM/pointllm/train/pointllm_trainer.py b/ThirdParty/PointLLM/pointllm/train/pointllm_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..096fa75d673e8f3f51b8cc33997c76cc927c1e9f
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/train/pointllm_trainer.py
@@ -0,0 +1,49 @@
+import os
+import torch
+import torch.nn as nn
+
+from transformers import Trainer
+from typing import Optional
+
+
+def unwrap_model(model: nn.Module) -> nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
+
+
+class PointLLMTrainer(Trainer):
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ # Save the model
+ _state_dict = state_dict
+ if _state_dict is None:
+ # Only save the model itself if we are using distributed training
+ model_to_save = unwrap_model(self.model)
+ _state_dict = model_to_save.state_dict()
+
+ weight_to_save = {}
+ keys_to_match = ['point_proj', 'embed_tokens', 'embed_in']
+ for k, v in _state_dict.items():
+ if any(key_match in k for key_match in keys_to_match):
+ weight_to_save[k] = v
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "point_proj")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f'point_proj.bin'))
+
+ super(PointLLMTrainer, self)._save(output_dir, state_dict)
diff --git a/ThirdParty/PointLLM/pointllm/train/train.py b/ThirdParty/PointLLM/pointllm/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c0f07f6980930fc991a05ac7a5aebf456f62879
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/train/train.py
@@ -0,0 +1,216 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+import pathlib
+from typing import Optional, List
+
+
+import transformers
+from pointllm.train.pointllm_trainer import PointLLMTrainer
+
+from pointllm import conversation as conversation_lib
+from pointllm.model import *
+from pointllm.data import make_object_point_data_module
+
+# * logger
+from pointllm.utils import build_logger
+
+IGNORE_INDEX = -100
+
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="")
+ version: Optional[str] = field(default="v1")
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default="ScanNet", metadata={"help": "Path to the training data."})
+ anno_path: str = field(default=None, metadata={"help": "Path to the utterance data. If None, will use referit3d by default."})
+ use_color: bool = field(default=False, metadata={"help": "Whether to use color."})
+ data_debug_num: int = field(default=0, metadata={"help": "Number of data to use in debug mode. If larger than 0, use debug mode, else use the whole data"})
+ split_train_val: bool = field(default=False, metadata={"help": "Whether to split train and val."})
+ split_ratio: float = field(default=0.9, metadata={"help": "Ratio of train and val."})
+ pointnum: int = field(default=8192, metadata={"help": "Number of points."})
+ conversation_types: List[str] = field(default_factory=lambda: ["simple_description"], metadata={"help": "Conversation types to use."})
+ is_multimodal: bool = True
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ # * can refer to https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/trainer#transformers.TrainingArgument
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ model_max_length: int = field(
+ default=2048,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ model_debug: bool = field(default=False, metadata={"help": "Whether to use small model."}) # * when True, build the model from config only instead of loading pretrained checkpoints
+ fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."})
+ fix_pointnet: bool = field(default=True, metadata={"help": "Whether to fix the PointNet."})
+
+ remove_unused_columns: bool = field(default=False)
+ force_fsdp: bool = field(default=False)
+
+ # * for two stage training
+ tune_mm_mlp_adapter: bool = field(default=True) # * set True when pre-training, and false when fine-tuning
+ stage_2: bool = field(default=False) # * set True when fine-tuning
+ pretrained_mm_mlp_adapter: Optional[str] = field(default=None) # * path to the pre-trained projector & output_embed & input_embed
+ detatch_point_token: bool = field(default=False) # * deprecated
+ # * point backbone ckpt path
+ point_backbone_ckpt: str = field(default=None)
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ training_args.log_level = "info" # * default is passive(warning)
+ # * build logger
+ logger = build_logger(__name__, training_args.output_dir + '/train.log')
+
+ if training_args.model_debug:
+ # * do not load checkpoint, load from config
+ config = transformers.AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+ model = PointLLMLlamaForCausalLM._from_config(config)
+ else:
+ model = PointLLMLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ )
+
+ model.config.use_cache = False
+
+ if training_args.fix_llm:
+ # * This will fix all the parameters
+ logger.info("LLM is fixed. Fix_llm flag is set to True")
+ # * fix llama, lm_head, pointnet, projection layer here
+ model.requires_grad_(False)
+ model.get_model().fix_llm = True
+ model.get_model().point_proj.requires_grad_(True)
+ model.get_model().point_backbone.requires_grad_(True) # * set as True for fsdp, use fix_pointnet flag to control
+ else:
+ model.get_model().fix_llm = False
+ logger.warning("LLM is trainable. Fix_llm flag is set to False")
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ if model_args.version == "v0" or "v0" in model_args.model_name_or_path:
+ raise ValueError("v0 is deprecated.")
+ else:
+ tokenizer.pad_token = tokenizer.unk_token
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
+
+ if not training_args.fix_pointnet:
+ # * not fix pointnet
+ logger.info("Point backbone is trainable. Fix_pointnet flag is set to False, pointnet grad will be recorded.")
+ model.get_model().fix_pointnet = False
+ else:
+ logger.info("Point backbone is fixed. Fix_pointnet flag is set to True, pointnet grad will not be recorded.")
+ model.get_model().fix_pointnet = True # * use with torch.inference_mode to control, not requires_grad for fsdp for second stage
+ if not training_args.stage_2:
+ logger.info("Set requires_grad of point backbone to False")
+ model.get_model().point_backbone.requires_grad_(False) # * fix pointnet for first stage, need for fsdp in stage2
+
+ if training_args.tune_mm_mlp_adapter:
+ # * not fix the projection layer
+ # * may need to set the embed_tokens to require_grad = True if added new tokens
+ # * this is done in initialize_tokenizer_point_backbone_config
+ logger.info("Point projection layer is trainable.")
+ else:
+ model.get_model().point_proj.requires_grad_(False)
+ logger.info("Point prejcetion layer is fixed.")
+
+ if not training_args.stage_2:
+ # * we assume in stage2, llm, point_backbone, and projection layer can be loaded from the model checkpoint
+ print(f"Default point_backbone_ckpt is {training_args.point_backbone_ckpt}.")
+ model.get_model().load_point_backbone_checkpoint(training_args.point_backbone_ckpt)
+ model.initialize_tokenizer_point_backbone_config(tokenizer=tokenizer, device=training_args.device, fix_llm=training_args.fix_llm)
+ else:
+ # * stage2
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer=tokenizer)
+
+ point_backbone_config = model.get_model().point_backbone_config
+
+ data_args.point_token_len = point_backbone_config['point_token_len']
+ data_args.mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
+ data_args.point_backbone_config = point_backbone_config
+
+ params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
+ if len(params_no_grad) > 0:
+ if training_args.fsdp is not None and len(training_args.fsdp) > 0:
+ if len(params_no_grad) < 10:
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
+ else:
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
+ print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
+ print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ def patch_FSDP_use_orig_params(func):
+ def wrap_func(*args, **kwargs):
+ use_orig_params = kwargs.pop('use_orig_params', True)
+ return func(*args, **kwargs, use_orig_params=use_orig_params)
+ return wrap_func
+
+ FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
+
+ data_module = make_object_point_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+
+ trainer = PointLLMTrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/ThirdParty/PointLLM/pointllm/train/train_mem.py b/ThirdParty/PointLLM/pointllm/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..67d8035750cd9a463547eac788dc856e79375ad2
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/train/train_mem.py
@@ -0,0 +1,13 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from pointllm.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+replace_llama_attn_with_flash_attn()
+
+from pointllm.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/ThirdParty/PointLLM/pointllm/utils.py b/ThirdParty/PointLLM/pointllm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..95a35e802b162dde6d4b83d50100b515113a3719
--- /dev/null
+++ b/ThirdParty/PointLLM/pointllm/utils.py
@@ -0,0 +1,154 @@
+import logging
+import logging.handlers
+import os
+import sys
+import json
+
+import requests
+
+import yaml
+from easydict import EasyDict
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def merge_new_config(config, new_config):
+ for key, val in new_config.items():
+ if not isinstance(val, dict):
+ if key == '_base_':
+ with open(new_config['_base_'], 'r') as f:
+ try:
+ val = yaml.load(f, Loader=yaml.FullLoader)
+ except Exception: # e.g. older PyYAML without FullLoader
+ val = yaml.load(f)
+ config[key] = EasyDict()
+ merge_new_config(config[key], val)
+ else:
+ config[key] = val
+ continue
+ if key not in config:
+ config[key] = EasyDict()
+ merge_new_config(config[key], val)
+ return config
+
+def cfg_from_yaml_file(cfg_file):
+ config = EasyDict()
+ with open(cfg_file, 'r') as f:
+ new_config = yaml.load(f, Loader=yaml.FullLoader)
+ merge_new_config(config=config, new_config=new_config)
+ return config
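+
+# Usage sketch (path and keys are hypothetical):
+#   cfg = cfg_from_yaml_file('cfgs/point_backbone.yaml')
+#   print(cfg.model)  # EasyDict allows attribute-style access to the merged config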
+
+
+def build_logger(logger_name, logger_filepath):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ # * get the logger_file's directory, and create it if not exist
+ logger_filedir = os.path.dirname(logger_filepath)
+ os.makedirs(logger_filedir, exist_ok=True)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ logger_filepath, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
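+
+# Usage sketch (path is hypothetical): all loggers share one timed-rotating file handler,
+# and stdout/stderr are redirected through StreamToLogger defined below.
+#   logger = build_logger(__name__, 'outputs/run/train.log')
+#   logger.info("training started")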
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ # build the request body with json.dumps so quotes and backslashes in text are escaped correctly
+ data = json.dumps({"input": text}).encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
diff --git a/ThirdParty/PointLLM/pyproject.toml b/ThirdParty/PointLLM/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..4660f79fb94ad9363c9072525e3f876dda29c9e5
--- /dev/null
+++ b/ThirdParty/PointLLM/pyproject.toml
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "pointllm"
+version = "0.1.2"
+description = "Empower large language models to understand point clouds."
+readme = "README.md"
+requires-python = ">=3.8"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+]
+dependencies = [
+ "accelerate", "einops", "fastapi", "gradio", "markdown2[all]", "numpy",
+ "requests", "sentencepiece", "tokenizers==0.12.1",
+ "torch>=2.0", "torchvision", "uvicorn", "wandb",
+ "shortuuid",
+ "deepspeed", "peft",
+ "transformers @ git+https://github.com/huggingface/transformers.git@cae78c46",
+ "openai", "tqdm",
+ "easydict", "timm==0.4.12", "ftfy==6.0.1", "regex", "open3d==0.16.0", "h5py", "termcolor",
+ "plyfile", "nltk", "rouge", "scikit-learn", "py-rouge"
+]
+
+[tool.setuptools.packages.find]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+[tool.wheel]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
diff --git a/ThirdParty/PointLLM/scripts/PointLLM_train_stage1.sh b/ThirdParty/PointLLM/scripts/PointLLM_train_stage1.sh
new file mode 100755
index 0000000000000000000000000000000000000000..392f9fec15811df40889690a21b15a9e29b61b3c
--- /dev/null
+++ b/ThirdParty/PointLLM/scripts/PointLLM_train_stage1.sh
@@ -0,0 +1,43 @@
+master_port=$((RANDOM % (65535 - 49152 + 1) + 49152))
+# Get the filename without extension
+filename=$(basename "$0" | cut -f 1 -d '.')
+
+dir_path=PointLLM
+model_name_or_path=checkpoints/PointLLM_7B_v1.1_init
+data_path=data/objaverse_data
+anno_path=data/anno_data/PointLLM_brief_description_660K_filtered.json # or PointLLM_brief_description_660K.json (including val sets)
+output_dir=outputs/PointLLM_train_stage1/$filename
+point_backbone_ckpt=$model_name_or_path/point_bert_v1.2.pt
+
+cd $dir_path
+
+PYTHONPATH=$dir_path:$PYTHONPATH \
+torchrun --nnodes=1 --nproc_per_node=8 --master_port=$master_port pointllm/train/train_mem.py \
+ --model_name_or_path $model_name_or_path \
+ --data_path $data_path \
+ --anno_path $anno_path \
+ --output_dir $output_dir \
+ --version v1 \
+ --model_max_length 2048 \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "no" \
+ --save_steps 2400 \
+ --save_total_limit 1 \
+ --learning_rate 2e-3 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --bf16 True \
+ --evaluation_strategy "no" \
+ --fix_llm True \
+ --fix_pointnet True \
+ --gradient_checkpointing True \
+ --report_to wandb \
+ --run_name $filename \
+ --point_backbone_ckpt $point_backbone_ckpt \
+ --use_color True
\ No newline at end of file
diff --git a/ThirdParty/PointLLM/scripts/PointLLM_train_stage2.sh b/ThirdParty/PointLLM/scripts/PointLLM_train_stage2.sh
new file mode 100755
index 0000000000000000000000000000000000000000..eda415daf5c96b6367aab51e26468ae49f06ed60
--- /dev/null
+++ b/ThirdParty/PointLLM/scripts/PointLLM_train_stage2.sh
@@ -0,0 +1,46 @@
+master_port=$((RANDOM % (65535 - 49152 + 1) + 49152))
+# Get the filename without extension
+filename=$(basename "$0" | cut -f 1 -d '.')
+
+dir_path=PointLLM
+
+model_name_or_path=outputs/PointLLM_train_stage1/PointLLM_train_stage1 # Path to the output dir of stage 1 training
+data_path=data/objaverse_data
+anno_path=data/anno_data/PointLLM_complex_instruction_70K.json
+output_dir=outputs/PointLLM_train_stage2/$filename
+
+cd $dir_path
+
+PYTHONPATH=$dir_path:$PYTHONPATH \
+torchrun --nnodes=1 --nproc_per_node=8 --master_port=$master_port pointllm/train/train_mem.py \
+ --model_name_or_path $model_name_or_path \
+ --data_path $data_path \
+ --anno_path $anno_path \
+ --output_dir $output_dir \
+ --version v1 \
+ --model_max_length 2048 \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 4 \
+ --per_device_eval_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --eval_steps 100 \
+ --save_strategy "no" \
+ --save_steps 2400 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --bf16 True \
+ --fix_llm False \
+ --fix_pointnet True \
+ --report_to wandb \
+ --run_name $filename \
+ --gradient_checkpointing True \
+ --stage_2 True \
+ --fsdp "full_shard auto_wrap" \
+ --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
+ --conversation_types "detailed_description" "single_round" "multi_round" \
+ --use_color True
\ No newline at end of file
diff --git a/ThirdParty/Rignet_utils/Rignet_loss.py b/ThirdParty/Rignet_utils/Rignet_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0803afe7d3d2c5b35830400825d8139acc28310
--- /dev/null
+++ b/ThirdParty/Rignet_utils/Rignet_loss.py
@@ -0,0 +1,163 @@
+#-------------------------------------------------------------------------------
+# Name: Rignet_loss.py
+# Purpose: utility functions for the loss and evaluation metrics in RigNet
+# RigNet Copyright 2020 University of Massachusetts
+# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
+# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
+#-------------------------------------------------------------------------------
+
+
+from apted import APTED, Config
+import numpy as np
+
+class CustomConfig(Config):
+ valuecls = float
+
+ def rename(self, node1, node2):
+ """Compares attribute .value of trees"""
+ # return 1 if node1.value != node2.value else 0
+ # if not node1 or not node2:
+ # return 1.0
+ # return np.sqrt(np.sum((np.array(node1.pos) - np.array(node2.pos))**2))
+ return 0
+
+ def children(self, node):
+ """Get left and right children of binary tree"""
+ # return [x for x in (node.left, node.right) if x]
+ if not node:
+ return list()
+ else:
+ return node.children
+
+
+def getJointNum(skel):
+ this_level = [skel.root]
+ n_joint = 1
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ n_joint += len(p_node.children)
+ next_level += p_node.children
+ this_level = next_level
+ return n_joint
+
+
+def dist_pts2bone(pts, pos_1, pos_2):
+ l2 = np.sum((pos_2 - pos_1) ** 2)
+ if l2 < 1e-10:
+ dist_to_lineseg = np.linalg.norm(pts - pos_1, axis=1)
+ dist_proj = np.linalg.norm(pts - pos_1, axis=1)
+ else:
+ t_ = np.sum((pts - pos_1[np.newaxis, :]) * (pos_2 - pos_1), axis=1) / l2
+ t = np.clip(t_, 0, 1)
+ t_pos = pos_1[np.newaxis, :] + t[:, np.newaxis] * (pos_2 - pos_1)[np.newaxis, :]
+ lineseg_len = np.linalg.norm(pos_2 - pos_1)
+ dist_proj = np.zeros(len(t_))
+ dist_proj[np.argwhere(t_ < 0.5).squeeze()] = np.abs(t_[np.argwhere(t_ < 0.5).squeeze()] - 0.0) * lineseg_len
+ dist_proj[np.argwhere(t_ >= 0.5).squeeze()] = np.abs(t_[np.argwhere(t_ >= 0.5).squeeze()] - 1.0) * lineseg_len
+ dist_to_lineseg = np.linalg.norm(pts - t_pos, axis=1)
+ return dist_to_lineseg, dist_proj
+
+
+def chamfer_dist(pt1, pt2):
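+ # symmetric Chamfer distance between point sets of shape (N1, 3) and (N2, 3):
+ # average nearest-neighbor distance computed in both directions, then averaged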
+ pt1 = pt1[np.newaxis, :, :]
+ pt2 = pt2[:, np.newaxis, :]
+ dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+ min_left = np.mean(np.min(dist, axis=0))
+ min_right = np.mean(np.min(dist, axis=1))
+ #print(min_left, min_right)
+ return (min_left + min_right) / 2
+
+
+def oneway_chamfer(pt_src, pt_dst):
+ pt1 = pt_src[np.newaxis, :, :]
+ pt2 = pt_dst[:, np.newaxis, :]
+ dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+ avg_dist = np.mean(np.min(dist, axis=0))
+ return avg_dist
+
+
+def getJointArr(skel):
+ joints = []
+ this_level = [skel.root]
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ joint_ = np.array(p_node.pos)
+ joint_ = joint_[np.newaxis, :]
+ joints.append(joint_)
+ next_level += p_node.children
+ this_level = next_level
+ joints = np.concatenate(joints, axis=0)
+ return joints
+
+
+def edit_dist(tree1, tree2):
+ #n_joint1 = getJointNum(tree2)
+ #n_joint2 = getJointNum(tree2)
+ apted = APTED(tree1.root, tree2.root, CustomConfig())
+ ted = apted.compute_edit_distance()
+ #ted /= max(n_joint1, n_joint2)
+ return ted
+
+
+def tree_dist(tree1, tree2, ted_weight):
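+ # blends joint Chamfer distance and tree edit distance; ted_weight in [0, 1] controls the mix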
+ # get edit distance
+ ted = edit_dist(tree1, tree2)
+
+ # get chamfer distance
+ joint_arr_1 = getJointArr(tree1)
+ joint_arr_2 = getJointArr(tree2)
+ cd = chamfer_dist(joint_arr_1, joint_arr_2)
+
+ return (1-ted_weight)*cd + ted_weight * ted
+
+
+def sample_bone(p_pos, ch_pos):
+ ray = ch_pos - p_pos
+ bone_length = np.sqrt(np.sum((p_pos - ch_pos) ** 2))
+ num_step = int(np.round(bone_length / 0.005)) # cast to int so np.repeat / np.arange get an integer count
+ i_step = np.arange(0, num_step + 1)
+ unit_step = ray / (num_step + 1e-30)
+ unit_step = np.repeat(unit_step, num_step+1, axis=0)
+ res = p_pos + unit_step * i_step[:, np.newaxis]
+ return res
+
+
+def sample_skel(skel):
+ bone_sample = []
+ this_level = [skel.root]
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ p_pos = np.array([p_node.pos])
+ next_level += p_node.children
+ for c_node in p_node.children:
+ ch_pos = np.array([c_node.pos])
+ res = sample_bone(p_pos, ch_pos)
+ bone_sample.append(res)
+ this_level = next_level
+ bone_sample = np.concatenate(bone_sample, axis=0)
+ return bone_sample
+
+
+def bone2bone_chamfer_dist(skel_1, skel_2):
+ bone_sample_1 = sample_skel(skel_1)
+ bone_sample_2 = sample_skel(skel_2)
+ pt1 = bone_sample_1[np.newaxis, :, :]
+ pt2 = bone_sample_2[:, np.newaxis, :]
+ dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+ min_left = np.mean(np.min(dist, axis=0))
+ min_right = np.mean(np.min(dist, axis=1))
+ # print(min_left, min_right)
+ return (min_left + min_right) / 2
+
+
+def joint2bone_chamfer_dist(skel1, skel2):
+ bone_sample_1 = sample_skel(skel1)
+ bone_sample_2 = sample_skel(skel2)
+ joint_1 = getJointArr(skel1)
+ joint_2 = getJointArr(skel2)
+ dist1 = oneway_chamfer(joint_1, bone_sample_2)
+ dist2 = oneway_chamfer(joint_2, bone_sample_1)
+ return (dist1 + dist2) / 2
\ No newline at end of file
diff --git a/ThirdParty/Rignet_utils/__init__.py b/ThirdParty/Rignet_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThirdParty/Rignet_utils/__pycache__/__init__.cpython-310.pyc b/ThirdParty/Rignet_utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0abf79f5166fad1a19e8c11107ffb93608c3451f
Binary files /dev/null and b/ThirdParty/Rignet_utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/Rignet_utils/__pycache__/binvox_rw.cpython-310.pyc b/ThirdParty/Rignet_utils/__pycache__/binvox_rw.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6b29eaa16a722bb7a1216b6dd954d7dedbbe608
Binary files /dev/null and b/ThirdParty/Rignet_utils/__pycache__/binvox_rw.cpython-310.pyc differ
diff --git a/ThirdParty/Rignet_utils/binvox_rw.py b/ThirdParty/Rignet_utils/binvox_rw.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e42024802c0428438d4ecf42a07d68cf285008f
--- /dev/null
+++ b/ThirdParty/Rignet_utils/binvox_rw.py
@@ -0,0 +1,246 @@
+# Copyright (C) 2012 Daniel Maturana
+# This file is part of binvox-rw-py.
+#
+# binvox-rw-py is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# binvox-rw-py is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
+#
+
+
+import numpy as np
+import struct
+
+
+class Voxels(object):
+ """ Holds a binvox model.
+ data is either a three-dimensional numpy boolean array (dense representation)
+ or a two-dimensional numpy float array (coordinate representation).
+
+ dims, translate and scale are the model metadata.
+
+ dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.
+
+ scale and translate relate the voxels to the original model coordinates.
+
+ To translate voxel coordinates i, j, k to original coordinates x, y, z:
+
+ x_n = (i+.5)/dims[0]
+ y_n = (j+.5)/dims[1]
+ z_n = (k+.5)/dims[2]
+ x = scale*x_n + translate[0]
+ y = scale*y_n + translate[1]
+ z = scale*z_n + translate[2]
+
+ """
+
+ def __init__(self, data, dims, translate, scale, axis_order):
+ self.data = data
+ self.dims = dims
+ self.translate = translate
+ self.scale = scale
+ assert (axis_order in ('xzy', 'xyz'))
+ self.axis_order = axis_order
+
+ def clone(self):
+ data = self.data.copy()
+ dims = self.dims[:]
+ translate = self.translate[:]
+ return Voxels(data, dims, translate, self.scale, self.axis_order)
+
+ def write(self, fp):
+ write(self, fp)
+
+def read_header(fp):
+ """ Read binvox header. Mostly meant for internal use.
+ """
+ line = fp.readline().strip()
+ if not line.startswith(b'#binvox'):
+ raise IOError('Not a binvox file')
+ dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
+ translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
+ scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
+ line = fp.readline()
+
+ return dims, translate, scale
+
+def read_as_3d_array(fp, fix_coords=True):
+ """ Read binary binvox format as array.
+
+ Returns the model with accompanying metadata.
+
+ Voxels are stored in a three-dimensional numpy array, which is simple and
+ direct, but may use a lot of memory for large models. (Storage requirement
+ is d^3 bytes, where d is the side length of the binvox model, since NumPy
+ boolean arrays use one byte per element.)
+
+ Doesn't do any checks on input except for the '#binvox' line.
+ """
+ dims, translate, scale = read_header(fp)
+ raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
+ # if just using reshape() on the raw data:
+ # indexing the array as array[i,j,k], the indices map into the
+ # coords as:
+ # i -> x
+ # j -> z
+ # k -> y
+ # if fix_coords is true, then data is rearranged so that
+ # mapping is
+ # i -> x
+ # j -> y
+ # k -> z
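+ # the remaining bytes are (value, run_length) pairs; np.repeat expands this
+ # run-length encoding into a flat occupancy array before reshaping to dims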
+ values, counts = raw_data[::2], raw_data[1::2]
+ data = np.repeat(values, counts).astype(bool)
+ data = data.reshape(dims)
+ if fix_coords:
+ # xzy to xyz TODO the right thing
+ data = np.transpose(data, (0, 2, 1))
+ axis_order = 'xyz'
+ else:
+ axis_order = 'xzy'
+ return Voxels(data, dims, translate, scale, axis_order)
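+
+# Usage sketch (path is hypothetical; binvox files must be opened in binary mode):
+#   with open('model.binvox', 'rb') as f:
+#       vox = read_as_3d_array(f)
+#   occupied = np.argwhere(vox.data)  # (N, 3) indices of filled voxels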
+
+def read_as_coord_array(fp, fix_coords=True):
+ """ Read binary binvox format as coordinates.
+
+ Returns binvox model with voxels in a "coordinate" representation, i.e. an
+ 3 x N array where N is the number of nonzero voxels. Each column
+ corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
+ of the voxel. (The odd ordering is due to the way binvox format lays out
+ data). Note that coordinates refer to the binvox voxels, without any
+ scaling or translation.
+
+ Use this to save memory if your model is very sparse (mostly empty).
+
+ Doesn't do any checks on input except for the '#binvox' line.
+ """
+ dims, translate, scale = read_header(fp)
+ raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
+
+ values, counts = raw_data[::2], raw_data[1::2]
+
+ sz = np.prod(dims)
+ index, end_index = 0, 0
+ end_indices = np.cumsum(counts)
+ indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
+
+ values = values.astype(bool)
+ indices = indices[values]
+ end_indices = end_indices[values]
+
+ nz_voxels = []
+ for index, end_index in zip(indices, end_indices):
+ nz_voxels.extend(range(index, end_index))
+ nz_voxels = np.array(nz_voxels)
+ # TODO are these dims correct?
+ # according to docs,
+ # index = x * wxh + z * width + y; // wxh = width * height = d * d
+
+ # use integer division so the recovered voxel coordinates stay exact integers under Python 3
+ x = nz_voxels // (dims[0]*dims[1])
+ zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
+ z = zwpy // dims[0]
+ y = zwpy % dims[0]
+ if fix_coords:
+ data = np.vstack((x, y, z))
+ axis_order = 'xyz'
+ else:
+ data = np.vstack((x, z, y))
+ axis_order = 'xzy'
+
+ #return Voxels(data, dims, translate, scale, axis_order)
+ return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
+
+def dense_to_sparse(voxel_data, dtype=int):
+ """ From dense representation to sparse (coordinate) representation.
+ No coordinate reordering.
+ """
+ if voxel_data.ndim!=3:
+ raise ValueError('voxel_data is wrong shape; should be 3D array.')
+ return np.asarray(np.nonzero(voxel_data), dtype)
+
+def sparse_to_dense(voxel_data, dims, dtype=bool):
+ if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
+ raise ValueError('voxel_data is wrong shape; should be 3xN array.')
+ if np.isscalar(dims):
+ dims = [dims]*3
+ dims = np.atleast_2d(dims).T
+ # truncate to integers
+ xyz = voxel_data.astype(int)
+ # discard voxels that fall outside dims
+ valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
+ xyz = xyz[:,valid_ix]
+ out = np.zeros(dims.flatten(), dtype=dtype)
+ out[tuple(xyz)] = True
+ return out
+
+#def get_linear_index(x, y, z, dims):
+ #""" Assuming xzy order. (y increasing fastest.
+ #TODO ensure this is right when dims are not all same
+ #"""
+ #return x*(dims[1]*dims[2]) + z*dims[1] + y
+
+def bwrite(fp,s):
+ fp.write(s.encode())
+
+def write_pair(fp,state, ctr):
+ fp.write(struct.pack('B',state))
+ fp.write(struct.pack('B',ctr))
+
+def write(voxel_model, fp):
+ """ Write binary binvox format.
+
+ Note that when saving a model in sparse (coordinate) format, it is first
+ converted to dense format.
+
+ Doesn't check if the model is 'sane'.
+
+ """
+ if voxel_model.data.ndim==2:
+ # TODO avoid conversion to dense
+ dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
+ else:
+ dense_voxel_data = voxel_model.data
+
+ bwrite(fp, '#binvox 1\n')
+ bwrite(fp, 'dim ' + ' '.join(map(str, voxel_model.dims)) + '\n')
+ bwrite(fp, 'translate ' + ' '.join(map(str, voxel_model.translate)) + '\n')
+ bwrite(fp, 'scale ' + str(voxel_model.scale) + '\n')
+ bwrite(fp, 'data\n')
+ if not voxel_model.axis_order in ('xzy', 'xyz'):
+ raise ValueError('Unsupported voxel model axis order')
+
+ if voxel_model.axis_order=='xzy':
+ voxels_flat = dense_voxel_data.flatten()
+ elif voxel_model.axis_order=='xyz':
+ voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
+
+ # keep a sort of state machine for writing run length encoding
+ state = voxels_flat[0]
+ ctr = 0
+ for c in voxels_flat:
+ if c==state:
+ ctr += 1
+ # if ctr hits max, dump
+ if ctr==255:
+ write_pair(fp, state, ctr)
+ ctr = 0
+ else:
+ # if switch state, dump
+ write_pair(fp, state, ctr)
+ state = c
+ ctr = 1
+ # flush out remainders
+ if ctr > 0:
+ write_pair(fp, state, ctr)
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
diff --git a/ThirdParty/Rignet_utils/mst_utils.py b/ThirdParty/Rignet_utils/mst_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf8b50e32f612736618cf93698750095efec2897
--- /dev/null
+++ b/ThirdParty/Rignet_utils/mst_utils.py
@@ -0,0 +1,179 @@
+#-------------------------------------------------------------------------------
+# Name: mst_utils.py
+# Purpose: utility functions for skeleton generation
+# RigNet Copyright 2020 University of Massachusetts
+# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
+# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
+#-------------------------------------------------------------------------------
+
+import sys
+import numpy as np
+from .rig_parser import TreeNode
+from .rig_parser import Skel
+import torch
+
+def inside_check(pts, vox):
+ """
+ Check whether points are inside or outside the mesh based on its voxelization
+ (an 88^3 voxel grid is assumed, matching the constants below).
+ :param pts: points to be checked
+ :param vox: voxelized mesh
+ :return: internal points, and their indices in the input array.
+ """
+ vc = (pts - vox.translate) / vox.scale * vox.dims[0]
+ vc = np.round(vc).astype(int)
+ ind1 = np.logical_and(np.all(vc >= 0, axis=1), np.all(vc < 88, axis=1))
+ vc = np.clip(vc, 0, 87)
+ ind2 = vox.data[vc[:, 0], vc[:, 1], vc[:, 2]]
+ ind = np.logical_and(ind1, ind2)
+ pts = pts[ind]
+ return pts, np.argwhere(ind).squeeze()
+
+
+def sample_on_bone(p_pos, ch_pos):
+ """
+ sample points on a bone
+ :param p_pos: parent joint position
+ :param ch_pos: child joint position
+ :return: an array of samples on this bone.
+ """
+ ray = ch_pos - p_pos
+ bone_length = np.sqrt(np.sum((p_pos - ch_pos) ** 2))
+ num_step = int(np.round(bone_length / 0.01)) # cast to int so np.repeat / np.arange get an integer count
+ i_step = np.arange(1, num_step + 1)
+ unit_step = ray / (num_step + 1e-30)
+ unit_step = np.repeat(unit_step[np.newaxis, :], num_step, axis=0)
+ res = p_pos + unit_step * i_step[:, np.newaxis]
+ return res
+
+
+def minKey(key, mstSet, nV):
+ # Initialize min value
+ min = sys.maxsize
+ for v in range(nV):
+ if key[v] < min and mstSet[v] == False:
+ min = key[v]
+ min_index = v
+ return min_index
+
+def primMST_normal(graph, init_id, normal_matrix):
+ """
+ Modified Prim's algorithm to generate a minimum spanning tree (MST).
+ :param graph: pairwise cost matrix
+ :param init_id: init node ID as root
+ :return: parent array, key array, init_id
+ """
+ nV = graph.shape[0]
+ key = [sys.maxsize] * nV
+ parent = [None] * nV
+ mstSet = [False] * nV
+ key[init_id] = 0
+ parent[init_id] = -1
+ previous_normal = np.zeros((nV, 3))
+
+ while not all(mstSet):
+ u = minKey(key, mstSet, nV)
+ mstSet[u] = True
+ if parent[u] >= 0:
+ previous_normal[u] = normal_matrix[u, parent[u]]
+ updated_normal = np.dot(previous_normal[u], normal_matrix[u, :].T) #1*n
+ updated_normal[updated_normal<0]=0
+ # print('updated_normal',updated_normal.shape)
+ graph[u, :] = graph[u, :] +(1e8*updated_normal**2+1)
+ graph[:, u] = graph[:, u] +(1e8*updated_normal**2+1)
+
+ for v in range(nV):
+
+ if graph[u, v] > 0 and mstSet[v] is False and key[v] > graph[u, v]:
+ key[v] = graph[u, v]
+ parent[v] = u
+
+
+ return parent, key, init_id
+
+
+def loadSkel_recur(p_node, parent_id, joint_name, joint_pos, parent):
+ """
+ Convert Prim's algorithm result to our skel/info format recursively
+ :param p_node: Root node
+ :param parent_id: parent name of current step of recursion.
+ :param joint_name: list of joint names
+ :param joint_pos: joint positions
+ :param parent: parent index returned by prim alg.
+ :return: p_node (root) will be expanded to linked with all joints
+ """
+ for i in range(len(parent)):
+ if parent[i] == parent_id:
+ if joint_name is not None:
+ ch_node = TreeNode(joint_name[i], tuple(joint_pos[i]))
+ else:
+ ch_node = TreeNode('joint_{}'.format(i), tuple(joint_pos[i]))
+ p_node.children.append(ch_node)
+ ch_node.parent = p_node
+ loadSkel_recur(ch_node, i, joint_name, joint_pos, parent)
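+
+# Usage sketch (cost_matrix, normal_matrix, joint_pos and root_id are assumed to be precomputed):
+#   parent, key, root_id = primMST_normal(cost_matrix, root_id, normal_matrix)
+#   root = TreeNode('root', tuple(joint_pos[root_id]))
+#   loadSkel_recur(root, root_id, None, joint_pos, parent)
+#   skel = Skel(); skel.root = root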
+
+
+def unique_rows(a):
+ """
+ remove repeat rows from a numpy array
+ """
+ a = np.ascontiguousarray(a)
+ unique_a = np.unique(a.view([('', a.dtype)]*a.shape[1]))
+ return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1]))
+
+
+def increase_cost_for_outside_bone(cost_matrix, joint_pos, vox):
+ """
+ increase connectivity cost for bones outside the mesh
+ """
+ for i in range(len(joint_pos)):
+ for j in range(i+1, len(joint_pos)):
+ bone_samples = sample_on_bone(joint_pos[i], joint_pos[j])
+ bone_samples_vox = (bone_samples - vox.translate) / vox.scale * vox.dims[0]
+ bone_samples_vox = np.round(bone_samples_vox).astype(int)
+
+ ind1 = np.logical_and(np.all(bone_samples_vox >= 0, axis=1), np.all(bone_samples_vox < vox.dims[0], axis=1))
+ bone_samples_vox = np.clip(bone_samples_vox, 0, vox.dims[0]-1)
+ ind2 = vox.data[bone_samples_vox[:, 0], bone_samples_vox[:, 1], bone_samples_vox[:, 2]]
+ in_flags = np.logical_and(ind1, ind2)
+ outside_bone_sample = np.sum(in_flags == False)
+
+ if outside_bone_sample > 1:
+ cost_matrix[i, j] = 2 * outside_bone_sample
+ cost_matrix[j, i] = 2 * outside_bone_sample
+ if np.abs(joint_pos[i, 0]) < 2e-2 and np.abs(joint_pos[j, 0]) < 2e-2:
+ cost_matrix[i, j] *= 0.5
+ cost_matrix[j, i] *= 0.5
+ return cost_matrix
+
+def increase_cost_for_outside_bone_tensor(cost_matrix, joint_pos, vox,resolution=64):
+ """
+ increase connectivity cost for bones outside the mesh
+ vox is a tensor of size (N, 3), where N is the number of voxels inside the mesh; coordinates range over (0, resolution)
+ """
+
+ vox = torch.clamp(vox, 0, resolution-1).long()
+ for i in range(len(joint_pos)):
+ for j in range(i+1, len(joint_pos)):
+ bone_samples = sample_on_bone(joint_pos[i], joint_pos[j]) # return coordinates of points on the bone
+ bone_samples_vox = bone_samples * (resolution/2) + (resolution/2)
+ bone_samples_vox = np.round(bone_samples_vox).astype(int)
+ bone_samples_vox = np.clip(bone_samples_vox, 0, resolution-1)
+
+ vox_remap = torch.zeros((resolution,resolution,resolution))
+ vox_remap[vox[:,0],vox[:,1],vox[:,2]] = 1
+ vox_remap = vox_remap.numpy()
+ inside_index = vox_remap[bone_samples_vox[:,0],bone_samples_vox[:,1],bone_samples_vox[:,2]]
+ outside_bone_sample = np.sum(inside_index == 0)
+
+
+ # check the intersection of the bone with the mesh
+
+ if outside_bone_sample > 1:
+ cost_matrix[i, j] = 2 * outside_bone_sample
+ cost_matrix[j, i] = 2 * outside_bone_sample
+ if np.abs(joint_pos[i, 0]) < 2e-2 and np.abs(joint_pos[j, 0]) < 2e-2:
+ cost_matrix[i, j] *= 0.5
+ cost_matrix[j, i] *= 0.5
+ return cost_matrix
+
+
diff --git a/ThirdParty/Rignet_utils/rig_parser.py b/ThirdParty/Rignet_utils/rig_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6c2ff6659abc88bdd0adc43931d17457b63b19
--- /dev/null
+++ b/ThirdParty/Rignet_utils/rig_parser.py
@@ -0,0 +1,268 @@
+#-------------------------------------------------------------------------------
+# Name: rig_parser.py
+# Purpose: classes for skeleton and rig
+# RigNet Copyright 2020 University of Massachusetts
+# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
+# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
+#-------------------------------------------------------------------------------
+
+import numpy as np
+
+try:
+ import Queue as Q # ver. < 3.0
+except ImportError:
+ import queue as Q
+
+class Node(object):
+ def __init__(self, name, pos):
+ self.name = name
+ self.pos = pos
+
+
+class TreeNode(Node):
+ def __init__(self, name, pos):
+ super(TreeNode, self).__init__(name, pos)
+ self.children = []
+ self.parent = None
+
+class Info:
+ """
+ Wrapper class for rig information
+ """
+ def __init__(self, filename=None):
+ self.joint_pos = {}
+ self.joint_skin = []
+ self.root = None
+ if filename is not None:
+ self.load(filename)
+
+ def load(self, filename):
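+ # rig text format (one record per line): 'joints name x y z', 'root name',
+ # 'skin vertex_idx joint_name weight ...', and 'hier parent_name child_name'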
+ with open(filename, 'r') as f_txt:
+ lines = f_txt.readlines()
+ for line in lines:
+ word = line.split()
+ if word[0] == 'joints':
+ self.joint_pos[word[1]] = [float(word[2]), float(word[3]), float(word[4])]
+ elif word[0] == 'root':
+ root_pos = self.joint_pos[word[1]]
+ self.root = TreeNode(word[1], (root_pos[0], root_pos[1], root_pos[2]))
+ elif word[0] == 'skin':
+ skin_item = word[1:]
+ self.joint_skin.append(skin_item)
+ self.loadHierarchy_recur(self.root, lines, self.joint_pos)
+
+ def loadHierarchy_recur(self, node, lines, joint_pos):
+ for li in lines:
+ if li.split()[0] == 'hier' and li.split()[1] == node.name:
+ pos = joint_pos[li.split()[2]]
+ ch_node = TreeNode(li.split()[2], tuple(pos))
+ node.children.append(ch_node)
+ ch_node.parent = node
+ self.loadHierarchy_recur(ch_node, lines, joint_pos)
+
+ def save(self, filename):
+ with open(filename, 'w') as file_info:
+ for key, val in self.joint_pos.items():
+ file_info.write(
+ 'joints {0} {1:.8f} {2:.8f} {3:.8f}\n'.format(key, val[0], val[1], val[2]))
+ file_info.write('root {}\n'.format(self.root.name))
+
+ for skw in self.joint_skin:
+ cur_line = 'skin {0} '.format(skw[0])
+ for cur_j in range(1, len(skw), 2):
+ cur_line += '{0} {1:.4f} '.format(skw[cur_j], float(skw[cur_j+1]))
+ cur_line += '\n'
+ file_info.write(cur_line)
+
+ this_level = self.root.children
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ file_info.write('hier {0} {1}\n'.format(p_node.parent.name, p_node.name))
+ next_level += p_node.children
+ this_level = next_level
+ # get_skin_dict below parses the skin lines into {vertex_index: {joint_name: weight}} and also returns the number of skinned vertices
+
+ def get_skin_dict(self, filename):
+ skinning_dict = {}
+ with open (filename, 'r') as f:
+ lines = f.readlines()
+ skin_lines = [line for line in lines if line.startswith('skin')]
+ vertex_num = len(skin_lines)
+ for line in skin_lines:
+ word = line.split()
+ word = word[1:]
+ skin_vertex = {}
+ for i in range(1,len(word),2):
+ skin_vertex[word[i]] = float(word[i+1])
+ skinning_dict[word[0]] = skin_vertex
+ return skinning_dict,vertex_num
+
+ def save_as_skel_format(self, filename):
+ fout = open(filename, 'w')
+ this_level = [self.root]
+ hier_level = 1
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ pos = p_node.pos
+ parent = p_node.parent.name if p_node.parent is not None else 'None'
+ line = '{0} {1} {2:8f} {3:8f} {4:8f} {5}\n'.format(hier_level, p_node.name, pos[0], pos[1], pos[2],
+ parent)
+ fout.write(line)
+ for c_node in p_node.children:
+ next_level.append(c_node)
+ this_level = next_level
+ hier_level += 1
+ fout.close()
+
+ def normalize(self, scale, trans):
+ for k, v in self.joint_pos.items():
+ # joint positions are stored as plain lists, so rebuild them rather than using in-place ops
+ self.joint_pos[k] = [p / scale - t for p, t in zip(v, trans)]
+
+
+ this_level = [self.root]
+ while this_level:
+ next_level = []
+ for node in this_level:
+ # node.pos is a tuple, so rebuild it rather than dividing in place
+ node.pos = (node.pos[0] / scale - trans[0], node.pos[1] / scale - trans[1], node.pos[2] / scale - trans[2])
+ for ch in node.children:
+ next_level.append(ch)
+ this_level = next_level
+
+ def get_joint_dict(self):
+ joint_dict = {}
+ this_level = [self.root]
+ while this_level:
+ next_level = []
+ for node in this_level:
+ joint_dict[node.name] = node.pos
+ next_level += node.children
+ this_level = next_level
+ return joint_dict
+
+ def adjacent_matrix(self):
+ joint_pos = self.get_joint_dict()
+ joint_name_list = list(joint_pos.keys())
+ num_joint = len(joint_pos)
+ adj_matrix = np.zeros((num_joint, num_joint))
+ this_level = [self.root]
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ for c_node in p_node.children:
+ index_parent = joint_name_list.index(p_node.name)
+ index_children = joint_name_list.index(c_node.name)
+ adj_matrix[index_parent, index_children] = 1.
+ next_level += p_node.children
+ this_level = next_level
+ adj_matrix = adj_matrix + adj_matrix.transpose()
+ return adj_matrix
+
+
+class Skel:
+ """
+ Wrapper class for skeleton topology
+ """
+ def __init__(self, filename=None):
+ self.root = None
+ if filename is not None:
+ self.load(filename)
+
+ def load(self, filename):
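+ # skel text format (matching save() below): 'hier_level joint_name x y z parent_name [order]',
+ # where parent_name 'None' marks the root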
+ with open(filename, 'r') as fin:
+ lines = fin.readlines()
+ for li in lines:
+ words = li.split()
+ if words[5] == "None":
+ self.root = TreeNode(words[1], (float(words[2]), float(words[3]), float(words[4])))
+ if len(words) == 7:
+ has_order = True
+ self.root.order = int(words[6])
+ else:
+ has_order = False
+ break
+ self.loadSkel_recur(self.root, lines, has_order)
+
+ def loadSkel_recur(self, node, lines, has_order):
+ if has_order:
+ ch_queue = Q.PriorityQueue()
+ for li in lines:
+ words = li.split()
+ if words[5] == node.name:
+ ch_queue.put((int(li.split()[6]), li))
+ while not ch_queue.empty():
+ item = ch_queue.get()
+ li = item[1]
+ ch_node = TreeNode(li.split()[1], (float(li.split()[2]), float(li.split()[3]), float(li.split()[4])))
+ ch_node.order = int(li.split()[6])
+ node.children.append(ch_node)
+ ch_node.parent = node
+ self.loadSkel_recur(ch_node, lines, has_order)
+ else:
+ for li in lines:
+ words = li.split()
+ if words[5] == node.name:
+ ch_node = TreeNode(words[1], (float(words[2]), float(words[3]), float(words[4])))
+ node.children.append(ch_node)
+ ch_node.parent = node
+ self.loadSkel_recur(ch_node, lines, has_order)
+
+ def save(self, filename):
+ fout = open(filename, 'w')
+ this_level = [self.root]
+ hier_level = 1
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ pos = p_node.pos
+ parent = p_node.parent.name if p_node.parent is not None else 'None'
+ line = '{0} {1} {2:8f} {3:8f} {4:8f} {5}\n'.format(hier_level, p_node.name, pos[0], pos[1], pos[2], parent)
+ fout.write(line)
+ for c_node in p_node.children:
+ next_level.append(c_node)
+ this_level = next_level
+ hier_level += 1
+ fout.close()
+
+    def normalize(self, scale, trans):
+        this_level = [self.root]
+        while this_level:
+            next_level = []
+            for node in this_level:
+                # node.pos is stored as a tuple, so rebuild it instead of dividing in place
+                node.pos = (node.pos[0] / scale - trans[0],
+                            node.pos[1] / scale - trans[1],
+                            node.pos[2] / scale - trans[2])
+                for ch in node.children:
+                    next_level.append(ch)
+            this_level = next_level
+
+ def get_joint_pos(self):
+ joint_pos = {}
+ this_level = [self.root]
+ while this_level:
+ next_level = []
+ for node in this_level:
+ joint_pos[node.name] = node.pos
+ next_level += node.children
+ this_level = next_level
+ return joint_pos
+
+ def adjacent_matrix(self):
+ joint_pos = self.get_joint_pos()
+ joint_name_list = list(joint_pos.keys())
+ num_joint = len(joint_pos)
+ adj_matrix = np.zeros((num_joint, num_joint))
+ this_level = [self.root]
+ while this_level:
+ next_level = []
+ for p_node in this_level:
+ for c_node in p_node.children:
+ index_parent = joint_name_list.index(p_node.name)
+ index_children = joint_name_list.index(c_node.name)
+ adj_matrix[index_parent, index_children] = 1.
+ next_level += p_node.children
+ this_level = next_level
+ adj_matrix = adj_matrix + adj_matrix.transpose()
+ return adj_matrix
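+# Example (sketch): for a three-joint chain root -> a -> b, get_joint_pos() keeps
+# breadth-first order ['root', 'a', 'b'], so adjacent_matrix() returns the
+# symmetric matrix [[0, 1, 0], [1, 0, 1], [0, 1, 0]].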
diff --git a/ThirdParty/Rignet_utils/utils.py b/ThirdParty/Rignet_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..75dc9bea7edb077149b0ddc5bf0e0a1abf0498e8
--- /dev/null
+++ b/ThirdParty/Rignet_utils/utils.py
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------------------------
+# Name: utils.py
+# Purpose:     utility functions for skeleton generation
+# RigNet Copyright 2020 University of Massachusetts
+# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
+# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
+#-------------------------------------------------------------------------------
+
+import numpy as np
+from .rig_parser import Info, TreeNode
+from .mst_utils import increase_cost_for_outside_bone, loadSkel_recur, primMST_normal, increase_cost_for_outside_bone_tensor
+import trimesh
+import torch
+
+def get_skel(pred_joints, prob_matrix, vox):
+    """Build a skeleton from predicted joints and their pairwise connection probabilities.
+    The root is taken to be the joint with the highest self-connection probability (largest diagonal entry)."""
+ root_id = np.argmax(np.diag(prob_matrix))
+    # set the diagonal to 0 and normalize the prob_matrix
+ np.fill_diagonal(prob_matrix, 0)
+ prob_matrix = prob_matrix / (np.sum(prob_matrix, axis=1, keepdims=True)+1e-6)
+
+ cost_matrix = -np.log(prob_matrix + 1e-10)
+ if torch.is_tensor(vox):
+ cost_matrix = increase_cost_for_outside_bone_tensor(cost_matrix, pred_joints, vox)
+ else:
+ cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)
+
+ pred_joints = np.array(pred_joints)
+
+ # Create a matrix of shape (n, n, 3) where each element is the difference pred_joints[j] - pred_joints[i]
+ diff_matrix = pred_joints[:, np.newaxis, :] - pred_joints[np.newaxis, :, :]
+ norms = np.linalg.norm(diff_matrix, axis=2, keepdims=True)
+ norms[norms == 0] = 1
+ normal_matrix = diff_matrix / norms
+ np.fill_diagonal(normal_matrix[:, :, 0], 0)
+ np.fill_diagonal(normal_matrix[:, :, 1], 0)
+ np.fill_diagonal(normal_matrix[:, :, 2], 0)
+
+ pred_skel = Info()
+
+ parent, key, root_id = primMST_normal(cost_matrix, root_id, normal_matrix)
+
+ for i in range(len(parent)):
+ if parent[i] == -1:
+ pred_skel.root = TreeNode('root', tuple(pred_joints[i]))
+ break
+ loadSkel_recur(pred_skel.root, i, None, pred_joints, parent)
+ pred_skel.joint_pos = pred_skel.get_joint_dict()
+
+ return pred_skel, parent
+
+
+
+
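+# Usage sketch (placeholder inputs; 'vox' must be a voxel grid accepted by the
+# increase_cost_for_outside_bone helpers above):
+#
+#   J = 8
+#   joints = np.random.rand(J, 3) - 0.5      # predicted joint positions
+#   probs = np.random.rand(J, J)             # predicted joint-connectivity probabilities
+#   skel, parent = get_skel(joints, probs, vox)
+#   skel.save_as_skel_format('demo.skel')    # writes "<level> <name> x y z <parent>" rows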
diff --git a/ThirdParty/__init__.py b/ThirdParty/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThirdParty/__pycache__/__init__.cpython-310.pyc b/ThirdParty/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a92dd289073f9cfd045f734e5dec8b9347fb84f
Binary files /dev/null and b/ThirdParty/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/__init__.py b/ThirdParty/eg3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThirdParty/eg3d/__pycache__/__init__.cpython-310.pyc b/ThirdParty/eg3d/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7d8a4ddea086130cde36cd3a29574f216cc894d
Binary files /dev/null and b/ThirdParty/eg3d/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/calc_metrics.py b/ThirdParty/eg3d/calc_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..d401b22554e142a4146a0eb0fc952cc20742e3e7
--- /dev/null
+++ b/ThirdParty/eg3d/calc_metrics.py
@@ -0,0 +1,190 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Calculate quality metrics for previous training run or pretrained network pickle."""
+
+import os
+import click
+import json
+import tempfile
+import copy
+import torch
+
+import dnnlib
+import legacy
+from metrics import metric_main
+from metrics import metric_utils
+from torch_utils import training_stats
+from torch_utils import custom_ops
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0 or not args.verbose:
+ custom_ops.verbosity = 'none'
+
+ # Configure torch.
+ device = torch.device('cuda', rank)
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ conv2d_gradfix.enabled = True
+
+ # Print network summary.
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+ if rank == 0 and args.verbose:
+ z = torch.empty([1, G.z_dim], device=device)
+ c = torch.empty([1, G.c_dim], device=device)
+ misc.print_module_summary(G, [z, c])
+
+ # Calculate each metric.
+ for metric in args.metrics:
+ if rank == 0 and args.verbose:
+ print(f'Calculating {metric}...')
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+ if rank == 0 and args.verbose:
+ print()
+
+ # Done.
+ if rank == 0 and args.verbose:
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
+
+def parse_comma_separated_list(s):
+ if isinstance(s, list):
+ return s
+ if s is None or s.lower() == 'none' or s == '':
+ return []
+ return s.split(',')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
+@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
+@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
+@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
+@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
+
+def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
+ """Calculate quality metrics for previous training run or pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Previous training run: look up options automatically, save result to JSONL file.
+ python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
+
+ \b
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
+
+ \b
+ Recommended metrics:
+ fid50k_full Frechet inception distance against the full dataset.
+ kid50k_full Kernel inception distance against the full dataset.
+      pr50k3_full  Precision and recall against the full dataset.
+ ppl2_wend Perceptual path length in W, endpoints, full image.
+ eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
+ eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
+ eqr50k Equivariance w.r.t. rotation (EQ-R).
+
+ \b
+ Legacy metrics:
+ fid50k Frechet inception distance against 50k real images.
+ kid50k Kernel inception distance against 50k real images.
+ pr50k3 Precision and recall against 50k real images.
+ is50k Inception score for CIFAR-10.
+ """
+ dnnlib.util.Logger(should_flush=True)
+
+ # Validate arguments.
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ if not args.num_gpus >= 1:
+ ctx.fail('--gpus must be at least 1')
+
+ # Load network.
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+ ctx.fail('--network must point to a file or URL')
+ if args.verbose:
+ print(f'Loading network from "{network_pkl}"...')
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+ network_dict = legacy.load_network_pkl(f)
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+ # Initialize dataset options.
+ if data is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
+ elif network_dict['training_set_kwargs'] is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+ else:
+ ctx.fail('Could not look up dataset options; please specify --data')
+
+ # Finalize dataset options.
+ args.dataset_kwargs.resolution = args.G.img_resolution
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+ if mirror is not None:
+ args.dataset_kwargs.xflip = mirror
+
+ # Print dataset options.
+ if args.verbose:
+ print('Dataset options:')
+ print(json.dumps(args.dataset_kwargs, indent=2))
+
+ # Locate run dir.
+ args.run_dir = None
+ if os.path.isfile(network_pkl):
+ pkl_dir = os.path.dirname(network_pkl)
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+ args.run_dir = pkl_dir
+
+ # Launch processes.
+ if args.verbose:
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ calc_metrics() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/camera_utils.py b/ThirdParty/eg3d/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d4be88a575b4f43cce42f71222215e9b912d9f9
--- /dev/null
+++ b/ThirdParty/eg3d/camera_utils.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+
+from training.volumetric_rendering import math_utils
+
+class GaussianCameraPoseSampler:
+ """
+ Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+ Camera is specified as looking at the origin.
+ If horizontal and vertical stddev (specified in radians) are zero, gives a
+ deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+ The coordinate system is specified with y-up, z-forward, x-left.
+ Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+ vertical mean is the polar angle (angle from the y axis) in radians.
+ A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ camera is specified as looking at 'lookat_position', a 3-vector.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+class UniformCameraPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
+
+ Example:
+ For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+ cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
+ v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+def create_cam2world_matrix(forward_vector, origin):
+ """
+ Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+ Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
+ """
+
+ forward_vector = math_utils.normalize_vecs(forward_vector)
+ up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
+
+ right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
+ up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
+
+ rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
+
+ translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ translation_matrix[:, :3, 3] = origin
+ cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+ assert(cam2world.shape[1:] == (4, 4))
+ return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees, device='cpu'):
+ """
+ Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+ Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+ Assumes principal point is at image center.
+ """
+
+ focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+ return intrinsics
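+# Usage sketch: composing a camera label in the 25-dim layout used by EG3D
+# (flattened 4x4 cam2world followed by flattened 3x3 intrinsics); the radius
+# and field of view below are illustrative values, not canonical ones.
+#
+#   cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2,
+#                                        torch.tensor([0.0, 0.0, 0.0]), radius=2.7)
+#   intrinsics = FOV_to_intrinsics(18.837)
+#   camera_params = torch.cat([cam2world.reshape(-1, 16),
+#                              intrinsics.reshape(-1, 9)], dim=1)   # shape (1, 25)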
\ No newline at end of file
diff --git a/ThirdParty/eg3d/dataset_tool.py b/ThirdParty/eg3d/dataset_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a400f770fa477ef09adf4804235be4d67898765a
--- /dev/null
+++ b/ThirdParty/eg3d/dataset_tool.py
@@ -0,0 +1,458 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Tool for creating ZIP/PNG based datasets."""
+
+import functools
+import gzip
+import io
+import json
+import os
+import pickle
+import re
+import sys
+import tarfile
+import zipfile
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+import click
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+#----------------------------------------------------------------------------
+
+def error(msg):
+ print('Error: ' + msg)
+ sys.exit(1)
+
+#----------------------------------------------------------------------------
+
+def parse_tuple(s: str) -> Tuple[int, int]:
+ '''Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ '''
+ if m := re.match(r'^(\d+)[x,](\d+)$', s):
+ return (int(m.group(1)), int(m.group(2)))
+ raise ValueError(f'cannot parse tuple {s}')
+
+#----------------------------------------------------------------------------
+
+def maybe_min(a: int, b: Optional[int]) -> int:
+ if b is not None:
+ return min(a, b)
+ return a
+
+#----------------------------------------------------------------------------
+
+def file_ext(name: Union[str, Path]) -> str:
+ return str(name).split('.')[-1]
+
+#----------------------------------------------------------------------------
+
+def is_image_ext(fname: Union[str, Path]) -> bool:
+ ext = file_ext(fname).lower()
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
+
+#----------------------------------------------------------------------------
+
+def open_image_folder(source_dir, *, max_images: Optional[int]):
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
+
+ # Load labels.
+ labels = {}
+ meta_fname = os.path.join(source_dir, 'dataset.json')
+ if os.path.isfile(meta_fname):
+ with open(meta_fname, 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ for idx, fname in enumerate(input_images):
+ arch_fname = os.path.relpath(fname, source_dir)
+ arch_fname = arch_fname.replace('\\', '/')
+ img = np.array(PIL.Image.open(fname))
+ yield dict(img=img, label=labels.get(arch_fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_image_zip(source, *, max_images: Optional[int]):
+ with zipfile.ZipFile(source, mode='r') as z:
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+
+ # Load labels.
+ labels = {}
+ if 'dataset.json' in z.namelist():
+ with z.open('dataset.json', 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = { x[0]: x[1] for x in labels }
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ with zipfile.ZipFile(source, mode='r') as z:
+ for idx, fname in enumerate(input_images):
+ with z.open(fname, 'r') as file:
+ img = PIL.Image.open(file) # type: ignore
+ img = np.array(img)
+ yield dict(img=img, label=labels.get(fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
+ import cv2 # pip install opencv-python # pylint: disable=import-error
+ import lmdb # pip install lmdb # pylint: disable=import-error
+
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
+
+ def iterate_images():
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ for idx, (_key, value) in enumerate(txn.cursor()):
+ try:
+ try:
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
+ if img is None:
+ raise IOError('cv2.imdecode failed')
+ img = img[:, :, ::-1] # BGR => RGB
+ except IOError:
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
+ yield dict(img=img, label=None)
+ if idx >= max_idx-1:
+ break
+ except:
+ print(sys.exc_info()[1])
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_cifar10(tarball: str, *, max_images: Optional[int]):
+ images = []
+ labels = []
+
+ with tarfile.open(tarball, 'r:gz') as tar:
+ for batch in range(1, 6):
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
+ with tar.extractfile(member) as file:
+ data = pickle.load(file, encoding='latin1')
+ images.append(data['data'].reshape(-1, 3, 32, 32))
+ labels.append(data['labels'])
+
+ images = np.concatenate(images)
+ labels = np.concatenate(labels)
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def open_mnist(images_gz: str, *, max_images: Optional[int]):
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+ assert labels_gz != images_gz
+ images = []
+ labels = []
+
+ with gzip.open(images_gz, 'rb') as f:
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
+ with gzip.open(labels_gz, 'rb') as f:
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+ images = images.reshape(-1, 28, 28)
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+#----------------------------------------------------------------------------
+
+def make_transform(
+ transform: Optional[str],
+ output_width: Optional[int],
+ output_height: Optional[int]
+) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+ def scale(width, height, img):
+ w = img.shape[1]
+ h = img.shape[0]
+ if width == w and height == h:
+ return img
+ img = PIL.Image.fromarray(img)
+ ww = width if width is not None else w
+ hh = height if height is not None else h
+ img = img.resize((ww, hh), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop(width, height, img):
+ crop = np.min(img.shape[:2])
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
+ img = PIL.Image.fromarray(img, 'RGB')
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop_wide(width, height, img):
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
+ if img.shape[1] < width or ch < height:
+ return None
+
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
+ img = PIL.Image.fromarray(img, 'RGB')
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ img = np.array(img)
+
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
+ return canvas
+
+ if transform is None:
+ return functools.partial(scale, output_width, output_height)
+ if transform == 'center-crop':
+ if (output_width is None) or (output_height is None):
+            error('must specify --resolution=WxH when using ' + transform + ' transform')
+ return functools.partial(center_crop, output_width, output_height)
+ if transform == 'center-crop-wide':
+ if (output_width is None) or (output_height is None):
+            error('must specify --resolution=WxH when using ' + transform + ' transform')
+ return functools.partial(center_crop_wide, output_width, output_height)
+ assert False, 'unknown transform'
+
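+# Usage sketch: build a 512x512 center-crop transform and apply it to one frame
+# (the zero image below is just a stand-in for real pixel data).
+#
+#   tf = make_transform('center-crop', 512, 512)
+#   out = tf(np.zeros((720, 1280, 3), dtype=np.uint8))   # -> (512, 512, 3) uint8 array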
+#----------------------------------------------------------------------------
+
+def open_dataset(source, *, max_images: Optional[int]):
+ if os.path.isdir(source):
+ if source.rstrip('/').endswith('_lmdb'):
+ return open_lmdb(source, max_images=max_images)
+ else:
+ return open_image_folder(source, max_images=max_images)
+ elif os.path.isfile(source):
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
+ return open_cifar10(source, max_images=max_images)
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+ return open_mnist(source, max_images=max_images)
+ elif file_ext(source) == 'zip':
+ return open_image_zip(source, max_images=max_images)
+ else:
+ assert False, 'unknown archive type'
+ else:
+ error(f'Missing input file or directory: {source}')
+
+#----------------------------------------------------------------------------
+
+def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+ dest_ext = file_ext(dest)
+
+ if dest_ext == 'zip':
+ if os.path.dirname(dest) != '':
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
+ zf.writestr(fname, data)
+ return '', zip_write_bytes, zf.close
+ else:
+        # If the output folder already exists, check that it is empty.
+ #
+ # Note: creating the output directory is not strictly
+ # necessary as folder_write_bytes() also mkdirs, but it's better
+ # to give an error message earlier in case the dest folder
+ # somehow cannot be created.
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+ error('--dest folder must be empty')
+ os.makedirs(dest, exist_ok=True)
+
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
+ with open(fname, 'wb') as fout:
+ if isinstance(data, str):
+ data = data.encode('utf8')
+ fout.write(data)
+ return dest, folder_write_bytes, lambda: None
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
+@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
+@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
+@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+def convert_dataset(
+ ctx: click.Context,
+ source: str,
+ dest: str,
+ max_images: Optional[int],
+ transform: Optional[str],
+ resolution: Optional[Tuple[int, int]]
+):
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+ The input dataset format is guessed from the --source argument:
+
+ \b
+ --source *_lmdb/ Load LSUN dataset
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
+ --source path/ Recursively load all images from path/
+ --source dataset.zip Recursively load all images from dataset.zip
+
+ Specifying the output format and path:
+
+ \b
+ --dest /path/to/dir Save output files under /path/to/dir
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
+
+ The output dataset format can be either an image folder or an uncompressed zip archive.
+    Zip archives make it easier to move datasets around file servers and clusters, and may
+ offer better training performance on network file systems.
+
+ Images within the dataset archive will be stored as uncompressed PNG.
+ Uncompressed PNGs can be efficiently decoded in the training loop.
+
+ Class labels are stored in a file called 'dataset.json' that is stored at the
+ dataset root folder. This file has the following structure:
+
+ \b
+ {
+ "labels": [
+ ["00000/img00000000.png",6],
+ ["00000/img00000001.png",9],
+ ... repeated for every image in the dataset
+ ["00049/img00049999.png",1]
+ ]
+ }
+
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
+ not containing class labels.
+
+ Image scale/crop and resolution requirements:
+
+ Output images must be square-shaped and they must all have the same power-of-two
+ dimensions.
+
+    To scale arbitrary input images to a specific width and height, use the
+    --resolution option. The output resolution is either the original input
+    resolution (if --resolution was not specified) or the one given with the
+    --resolution option.
+
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+ center crop transform on the input image. These options should be used with the
+ --resolution option. For example:
+
+ \b
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
+ --transform=center-crop-wide --resolution=512x384
+ """
+
+ PIL.Image.init() # type: ignore
+
+ if dest == '':
+ ctx.fail('--dest output filename or directory must not be an empty string')
+
+ num_files, input_iter = open_dataset(source, max_images=max_images)
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
+
+ if resolution is None: resolution = (None, None)
+ transform_image = make_transform(transform, *resolution)
+
+ dataset_attrs = None
+
+ labels = []
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
+ idx_str = f'{idx:08d}'
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
+
+ # Apply crop and resize.
+ img = transform_image(image['img'])
+
+ # Transform may drop images.
+ if img is None:
+ continue
+
+ # Error check to require uniform image attributes across
+ # the whole dataset.
+ channels = img.shape[2] if img.ndim == 3 else 1
+ cur_image_attrs = {
+ 'width': img.shape[1],
+ 'height': img.shape[0],
+ 'channels': channels
+ }
+ if dataset_attrs is None:
+ dataset_attrs = cur_image_attrs
+ width = dataset_attrs['width']
+ height = dataset_attrs['height']
+ if width != height:
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
+ if dataset_attrs['channels'] not in [1, 3, 4]:
+ error('Input images must be stored as RGB or grayscale')
+ if width != 2 ** int(np.floor(np.log2(width))):
+ error('Image width/height after scale and crop are required to be power-of-two')
+ elif dataset_attrs != cur_image_attrs:
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
+
+ # Save the image as an uncompressed PNG.
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB', 4: 'RGBA'}[channels])
+ if channels == 4: img = img.convert('RGB')
+ image_bits = io.BytesIO()
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
+
+ metadata = {
+ 'labels': labels if all(x is not None for x in labels) else None
+ }
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+ close_dest()
+
+#----------------------------------------------------------------------------
+
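+# Illustrative sketch (not exercised by the CLI above): read back the
+# 'dataset.json' described in the convert_dataset() docstring from a converted
+# .zip archive and summarize its labels.
+def inspect_dataset_json(archive_path: str) -> None:
+    with zipfile.ZipFile(archive_path) as z:
+        meta = json.loads(z.read('dataset.json'))
+    labels = meta.get('labels')
+    print('labeled images:', 0 if labels is None else len(labels))
+
+#----------------------------------------------------------------------------
+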
+if __name__ == "__main__":
+ convert_dataset() # pylint: disable=no-value-for-parameter
diff --git a/ThirdParty/eg3d/dnnlib/__init__.py b/ThirdParty/eg3d/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd91ed142e955581e83948455fb71cd837215f61
--- /dev/null
+++ b/ThirdParty/eg3d/dnnlib/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/ThirdParty/eg3d/dnnlib/__pycache__/__init__.cpython-310.pyc b/ThirdParty/eg3d/dnnlib/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b69694f8dff325adb4e7176d784e2e68dfd451c4
Binary files /dev/null and b/ThirdParty/eg3d/dnnlib/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/dnnlib/__pycache__/util.cpython-310.pyc b/ThirdParty/eg3d/dnnlib/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09c2f46daed7598d146e9060f7ac6f251451b7da
Binary files /dev/null and b/ThirdParty/eg3d/dnnlib/__pycache__/util.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/dnnlib/util.py b/ThirdParty/eg3d/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b67c4e312cd1b847ca21fd3b929802a57e6f6d
--- /dev/null
+++ b/ThirdParty/eg3d/dnnlib/util.py
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
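+# Usage sketch (illustrative values):
+#
+#   cfg = EasyDict(lr=1e-3)
+#   cfg.batch_size = 32                  # attribute-style write
+#   assert cfg['batch_size'] == 32 and cfg.lr == 1e-3
+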
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+    # Doesn't look like a URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+    # Some internet resources suggest using urllib.request.url2pathname(),
+    # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
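+# Usage sketch: open_url() accepts plain local paths, file:// URLs and HTTP(S)
+# URLs (with optional on-disk caching); the URL below is a placeholder, not a
+# real resource.
+#
+#   with open_url('https://example.com/weights.pkl', cache=True) as f:
+#       blob = f.read()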
diff --git a/ThirdParty/eg3d/environment.yml b/ThirdParty/eg3d/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..082bcaf51b257b8dfe6148fac2fcac263fee3f15
--- /dev/null
+++ b/ThirdParty/eg3d/environment.yml
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+name: eg3d
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python >= 3.8
+ - pip
+ - numpy>=1.20
+ - click>=8.0
+ - pillow=8.3.1
+ - scipy=1.7.1
+ - pytorch=1.11.0
+ - cudatoolkit=11.1
+ - requests=2.26.0
+ - tqdm=4.62.2
+ - ninja=1.10.2
+ - matplotlib=3.4.2
+ - imageio=2.9.0
+ - pip:
+ - imgui==1.3.0
+ - glfw==2.2.0
+ - pyopengl==3.1.5
+ - imageio-ffmpeg==0.4.3
+ - pyspng
+ - psutil
+ - mrcfile
+ - tensorboard
\ No newline at end of file
diff --git a/ThirdParty/eg3d/gen_samples.py b/ThirdParty/eg3d/gen_samples.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab4a22cc6f8e557542a0bb26ae0e8b8862c732c
--- /dev/null
+++ b/ThirdParty/eg3d/gen_samples.py
@@ -0,0 +1,230 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Generate images and shapes using pretrained network pickle."""
+
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+from tqdm import tqdm
+import mrcfile
+
+
+import legacy
+from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+from torch_utils import misc
+from training.triplane import TriPlaneGenerator
+
+
+#----------------------------------------------------------------------------
+
+def parse_range(s: Union[str, List]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ if m := range_re.match(p):
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+#----------------------------------------------------------------------------
+
+def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+ '''Parse a floating point 2-vector of syntax 'a,b'.
+
+ Example:
+ '0,1' returns (0,1)
+ '''
+ if isinstance(s, tuple): return s
+ parts = s.split(',')
+ if len(parts) == 2:
+ return (float(parts[0]), float(parts[1]))
+ raise ValueError(f'cannot parse 2-vector {s}')
+
+#----------------------------------------------------------------------------
+
+def make_transform(translate: Tuple[float,float], angle: float):
+ m = np.eye(3)
+ s = np.sin(angle/360.0*np.pi*2)
+ c = np.cos(angle/360.0*np.pi*2)
+ m[0][0] = c
+ m[0][1] = s
+ m[0][2] = translate[0]
+ m[1][0] = -s
+ m[1][1] = c
+ m[1][2] = translate[1]
+ return m
+
+#----------------------------------------------------------------------------
+
+def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
+ # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
+ voxel_origin = np.array(voxel_origin) - cube_length/2
+ voxel_size = cube_length / (N - 1)
+
+ overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
+ samples = torch.zeros(N ** 3, 3)
+
+ # transform first 3 columns
+ # to be the x, y, z index
+ samples[:, 2] = overall_index % N
+ samples[:, 1] = (overall_index.float() / N) % N
+ samples[:, 0] = ((overall_index.float() / N) / N) % N
+
+ # transform first 3 columns
+ # to be the x, y, z coordinate
+ samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
+ samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
+ samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
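+ # Editorial example: with N=2 and cube_length=2.0, voxel_origin becomes
+ # [-1, -1, -1] and voxel_size = 2.0, so each axis takes the values -1 and +1
+ # and the 8 samples are exactly the corners of the [-1, 1]^3 cube.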
+
+ num_samples = N ** 3
+
+ return samples.unsqueeze(0), voxel_origin, voxel_size
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+@click.option('--shapes', help='Export shapes as .mrc files viewable in ChimeraX', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
+@click.option('--shape-res', help='', type=int, required=False, metavar='int', default=512, show_default=True)
+@click.option('--fov-deg', help='Field of View of camera in degrees', type=float, required=False, metavar='float', default=18.837, show_default=True)
+@click.option('--shape-format', help='Shape Format', type=click.Choice(['.mrc', '.ply']), default='.mrc')
+@click.option('--reload_modules', help='Reload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ truncation_psi: float,
+ truncation_cutoff: int,
+ outdir: str,
+ shapes: bool,
+ shape_res: int,
+ fov_deg: float,
+ shape_format: str,
+ class_idx: Optional[int],
+ reload_modules: bool,
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Generate an image using pre-trained FFHQ model.
+ python gen_samples.py --outdir=output --trunc=0.7 --seeds=0-5 --shapes=True\\
+ --network=ffhq-rebalanced-128.pkl
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ # Specify reload_modules=True if you want code modifications to take effect; otherwise uses pickled code
+ if reload_modules:
+ print("Reloading Modules!")
+ G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
+ misc.copy_params_and_buffers(G, G_new, require_all=True)
+ G_new.neural_rendering_resolution = G.neural_rendering_resolution
+ G_new.rendering_kwargs = G.rendering_kwargs
+ G = G_new
+
+ os.makedirs(outdir, exist_ok=True)
+
+ cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor([0, 0, 0.2], device=device), radius=2.7, device=device)
+ intrinsics = FOV_to_intrinsics(fov_deg, device=device)
+
+ # Generate images.
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+
+ imgs = []
+ angle_p = -0.2
+ for angle_y, angle_p in [(.4, angle_p), (0, angle_p), (-.4, angle_p)]:
+ cam_pivot = torch.tensor(G.rendering_kwargs.get('avg_camera_pivot', [0, 0, 0]), device=device)
+ cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)
+ cam2world_pose = LookAtPoseSampler.sample(np.pi/2 + angle_y, np.pi/2 + angle_p, cam_pivot, radius=cam_radius, device=device)
+ conditioning_cam2world_pose = LookAtPoseSampler.sample(np.pi/2, np.pi/2, cam_pivot, radius=cam_radius, device=device)
+ camera_params = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+ conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+ ws = G.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
+ img = G.synthesis(ws, camera_params)['image']
+
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ imgs.append(img)
+
+ img = torch.cat(imgs, dim=2)
+
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+ if shapes:
+ # extract a 3D sigma (density) volume; the resulting .mrc can be viewed in ChimeraX from UCSF, or exported as a .ply mesh via marching cubes.
+ max_batch=1000000
+
+ samples, voxel_origin, voxel_size = create_samples(N=shape_res, voxel_origin=[0, 0, 0], cube_length=G.rendering_kwargs['box_warp'] * 1)#.reshape(1, -1, 3)
+ samples = samples.to(z.device)
+ sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=z.device)
+ transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=z.device)
+ transformed_ray_directions_expanded[..., -1] = -1
+
+ head = 0
+ with tqdm(total = samples.shape[1]) as pbar:
+ with torch.no_grad():
+ while head < samples.shape[1]:
+ torch.manual_seed(0)
+ sigma = G.sample(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :samples.shape[1]-head], z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, noise_mode='const')['sigma']
+ sigmas[:, head:head+max_batch] = sigma
+ head += max_batch
+ pbar.update(max_batch)
+
+ sigmas = sigmas.reshape((shape_res, shape_res, shape_res)).cpu().numpy()
+ sigmas = np.flip(sigmas, 0)
+
+ # Trim the border of the extracted cube
+ pad = int(30 * shape_res / 256)
+ pad_value = -1000
+ sigmas[:pad] = pad_value
+ sigmas[-pad:] = pad_value
+ sigmas[:, :pad] = pad_value
+ sigmas[:, -pad:] = pad_value
+ sigmas[:, :, :pad] = pad_value
+ sigmas[:, :, -pad:] = pad_value
+
+ if shape_format == '.ply':
+ from shape_utils import convert_sdf_samples_to_ply
+ convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, os.path.join(outdir, f'seed{seed:04d}.ply'), level=10)
+ elif shape_format == '.mrc': # output mrc
+ with mrcfile.new_mmap(os.path.join(outdir, f'seed{seed:04d}.mrc'), overwrite=True, shape=sigmas.shape, mrc_mode=2) as mrc:
+ mrc.data[:] = sigmas
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/gen_videos.py b/ThirdParty/eg3d/gen_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..de03d44c66f89999590979932792f6770c51fe69
--- /dev/null
+++ b/ThirdParty/eg3d/gen_videos.py
@@ -0,0 +1,331 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Generate lerp videos using pretrained network pickle."""
+
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import imageio
+import numpy as np
+import scipy.interpolate
+import torch
+from tqdm import tqdm
+import mrcfile
+
+import legacy
+
+from camera_utils import LookAtPoseSampler
+from torch_utils import misc
+#----------------------------------------------------------------------------
+
+def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
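+ # Editorial note: tiles a batch of images into a single grid image, e.g. an
+ # input of shape (8, 3, 256, 256) with grid_w=4, grid_h=2 becomes a
+ # (512, 1024, 3) uint8 HWC array ready for the video writer.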
+ batch_size, channels, img_h, img_w = img.shape
+ if grid_w is None:
+ grid_w = batch_size // grid_h
+ assert batch_size == grid_w * grid_h
+ if float_to_uint8:
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
+ img = img.permute(2, 0, 3, 1, 4)
+ img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
+ if chw_to_hwc:
+ img = img.permute(1, 2, 0)
+ if to_numpy:
+ img = img.cpu().numpy()
+ return img
+
+def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
+ # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
+ voxel_origin = np.array(voxel_origin) - cube_length/2
+ voxel_size = cube_length / (N - 1)
+
+ overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
+ samples = torch.zeros(N ** 3, 3)
+
+ # transform first 3 columns
+ # to be the x, y, z index
+ samples[:, 2] = overall_index % N
+ samples[:, 1] = (overall_index.float() / N) % N
+ samples[:, 0] = ((overall_index.float() / N) / N) % N
+
+ # transform first 3 columns
+ # to be the x, y, z coordinate
+ samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
+ samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
+ samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
+
+ num_samples = N ** 3
+
+ return samples.unsqueeze(0), voxel_origin, voxel_size
+
+#----------------------------------------------------------------------------
+
+def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14, cfg='FFHQ', image_mode='image', gen_shapes=False, device=torch.device('cuda'), **video_kwargs):
+ grid_w = grid_dims[0]
+ grid_h = grid_dims[1]
+
+ if num_keyframes is None:
+ if len(seeds) % (grid_w*grid_h) != 0:
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
+ num_keyframes = len(seeds) // (grid_w*grid_h)
+
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
+ for idx in range(num_keyframes*grid_h*grid_w):
+ all_seeds[idx] = seeds[idx % len(seeds)]
+
+ if shuffle_seed is not None:
+ rng = np.random.RandomState(seed=shuffle_seed)
+ rng.shuffle(all_seeds)
+
+ camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
+ cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point, radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+ focal_length = 4.2647 if cfg != 'Shapenet' else 1.7074 # shapenet has higher FOV
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+ c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+ c = c.repeat(len(zs), 1)
+ ws = G.mapping(z=zs, c=c, truncation_psi=psi, truncation_cutoff=truncation_cutoff)
+ _ = G.synthesis(ws[:1], c[:1]) # warm up
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
+
+ # Interpolation.
+ grid = []
+ for yi in range(grid_h):
+ row = []
+ for xi in range(grid_w):
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
+ row.append(interp)
+ grid.append(row)
+
+ # Render video.
+ max_batch = 10000000
+ voxel_resolution = 512
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
+
+ if gen_shapes:
+ outdir = 'interpolation_{}_{}/'.format(all_seeds[0], all_seeds[1])
+ os.makedirs(outdir, exist_ok=True)
+ all_poses = []
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
+ imgs = []
+ for yi in range(grid_h):
+ for xi in range(grid_w):
+ pitch_range = 0.25
+ yaw_range = 0.35
+ cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
+ 3.14/2 -0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
+ camera_lookat_point, radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+ all_poses.append(cam2world_pose.squeeze().cpu().numpy())
+ focal_length = 4.2647 if cfg != 'Shapenet' else 1.7074 # shapenet has higher FOV
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+ c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+ interp = grid[yi][xi]
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
+
+ entangle = 'camera'
+ if entangle == 'conditioning':
+ c_forward = torch.cat([LookAtPoseSampler.sample(3.14/2,
+ 3.14/2,
+ camera_lookat_point,
+ radius=G.rendering_kwargs['avg_camera_radius'], device=device).reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+ w_c = G.mapping(z=zs[0:1], c=c[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
+ img = G.synthesis(ws=w_c, c=c_forward, noise_mode='const')[image_mode][0]
+ elif entangle == 'camera':
+ img = G.synthesis(ws=w.unsqueeze(0), c=c[0:1], noise_mode='const')[image_mode][0]
+ elif entangle == 'both':
+ w_c = G.mapping(z=zs[0:1], c=c[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
+ img = G.synthesis(ws=w_c, c=c[0:1], noise_mode='const')[image_mode][0]
+
+ if image_mode == 'image_depth':
+ img = -img
+ img = (img - img.min()) / (img.max() - img.min()) * 2 - 1
+
+ imgs.append(img)
+
+ if gen_shapes:
+ # generate shapes
+ print('Generating shape for frame %d / %d ...' % (frame_idx, num_keyframes * w_frames))
+
+ samples, voxel_origin, voxel_size = create_samples(N=voxel_resolution, voxel_origin=[0, 0, 0], cube_length=G.rendering_kwargs['box_warp'])
+ samples = samples.to(device)
+ sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device)
+ transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device)
+ transformed_ray_directions_expanded[..., -1] = -1
+
+ head = 0
+ with tqdm(total = samples.shape[1]) as pbar:
+ with torch.no_grad():
+ while head < samples.shape[1]:
+ torch.manual_seed(0)
+ sigma = G.sample_mixed(samples[:, head:head+max_batch], transformed_ray_directions_expanded[:, :samples.shape[1]-head], w.unsqueeze(0), truncation_psi=psi, noise_mode='const')['sigma']
+ sigmas[:, head:head+max_batch] = sigma
+ head += max_batch
+ pbar.update(max_batch)
+
+ sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy()
+ sigmas = np.flip(sigmas, 0)
+
+ pad = int(30 * voxel_resolution / 256)
+ pad_top = int(38 * voxel_resolution / 256)
+ sigmas[:pad] = 0
+ sigmas[-pad:] = 0
+ sigmas[:, :pad] = 0
+ sigmas[:, -pad_top:] = 0
+ sigmas[:, :, :pad] = 0
+ sigmas[:, :, -pad:] = 0
+
+ output_ply = True
+ if output_ply:
+ from shape_utils import convert_sdf_samples_to_ply
+ convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1, os.path.join(outdir, f'{frame_idx:04d}_shape.ply'), level=10)
+ else: # output mrc
+ with mrcfile.new_mmap(outdir + f'{frame_idx:04d}_shape.mrc', overwrite=True, shape=sigmas.shape, mrc_mode=2) as mrc:
+ mrc.data[:] = sigmas
+
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
+ video_out.close()
+ all_poses = np.stack(all_poses)
+
+ if gen_shapes:
+ print(all_poses.shape)
+ with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') as f:
+ np.save(f, all_poses)
+
+#----------------------------------------------------------------------------
+
+def parse_range(s: Union[str, List[int]]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ if m := range_re.match(p):
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+#----------------------------------------------------------------------------
+
+def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
+ '''Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ '''
+ if isinstance(s, tuple): return s
+ if m := re.match(r'^(\d+)[x,](\d+)$', s):
+ return (int(m.group(1)), int(m.group(2)))
+ raise ValueError(f'cannot parse tuple {s}')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
+@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
+@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
+@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, it is determined from the number of seeds given by --seeds.', default=None)
+@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
+@click.option('--outdir', help='Output directory', type=str, required=True, metavar='DIR')
+@click.option('--reload_modules', help='Reload persistent modules?', type=bool, required=False, metavar='BOOL', default=False, show_default=True)
+@click.option('--cfg', help='Config', type=click.Choice(['FFHQ', 'AFHQ', 'Shapenet']), required=False, metavar='STR', default='FFHQ', show_default=True)
+@click.option('--image_mode', help='Image mode', type=click.Choice(['image', 'image_depth', 'image_raw']), required=False, metavar='STR', default='image', show_default=True)
+@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True)
+@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True)
+@click.option('--shapes', type=bool, help='Gen shapes for shape interpolation', default=False, show_default=True)
+@click.option('--interpolate', type=bool, help='Interpolate between seeds', default=True, show_default=True)
+
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ shuffle_seed: Optional[int],
+ truncation_psi: float,
+ truncation_cutoff: int,
+ grid: Tuple[int,int],
+ num_keyframes: Optional[int],
+ w_frames: int,
+ outdir: str,
+ reload_modules: bool,
+ cfg: str,
+ image_mode: str,
+ sampling_multiplier: float,
+ nrr: Optional[int],
+ shapes: bool,
+ interpolate: bool,
+):
+ """Render a latent vector interpolation video.
+
+ Examples:
+
+ \b
+ # Render a 4x2 grid of interpolations for seeds 0 through 31.
+ python gen_videos.py --outdir=out --trunc=1 --seeds=0-31 --grid=4x2 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+ Animation length and seed keyframes:
+
+ The animation length is either determined based on the --seeds value or explicitly
+ specified using the --num-keyframes option.
+
+ When num keyframes is specified with --num-keyframes, the output video length
+ will be 'num_keyframes*w_frames' frames.
+
+ If --num-keyframes is not specified, the number of seeds given with
+ --seeds must be divisible by grid size W*H (--grid). In this case the
+ output video length will be '# seeds/(w*h)*w_frames' frames.
+ """
+
+ if not os.path.exists(outdir):
+ os.makedirs(outdir, exist_ok=True)
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+
+ G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
+ G.rendering_kwargs['depth_resolution_importance'] = int(G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)
+ if nrr is not None: G.neural_rendering_resolution = nrr
+
+ if truncation_cutoff == 0:
+ truncation_psi = 1.0 # a truncation cutoff of 0 means no truncation anyway
+ if truncation_psi == 1.0:
+ truncation_cutoff = 14 # no truncation, so the cutoff value does not matter
+
+ if interpolate:
+ output = os.path.join(outdir, 'interpolation.mp4')
+ gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, truncation_cutoff=truncation_cutoff, cfg=cfg, image_mode=image_mode, gen_shapes=shapes)
+ else:
+ for seed in seeds:
+ output = os.path.join(outdir, f'{seed}.mp4')
+ seeds_ = [seed]
+ gen_interp_video(G=G, mp4=output, bitrate='10M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds_, shuffle_seed=shuffle_seed, psi=truncation_psi, truncation_cutoff=truncation_cutoff, cfg=cfg, image_mode=image_mode)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/gui_utils/__init__.py b/ThirdParty/eg3d/gui_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/ThirdParty/eg3d/gui_utils/gl_utils.py b/ThirdParty/eg3d/gui_utils/gl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1312f027c23bbb80eb489bba7a0f9014d95ac5b0
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/gl_utils.py
@@ -0,0 +1,376 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import functools
+import contextlib
+import numpy as np
+import OpenGL.GL as gl
+import OpenGL.GL.ARB.texture_float
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+def init_egl():
+ assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
+ import OpenGL.EGL as egl
+ import ctypes
+
+ # Initialize EGL.
+ display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
+ assert display != egl.EGL_NO_DISPLAY
+ major = ctypes.c_int32()
+ minor = ctypes.c_int32()
+ ok = egl.eglInitialize(display, major, minor)
+ assert ok
+ assert major.value * 10 + minor.value >= 14
+
+ # Choose config.
+ config_attribs = [
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
+ egl.EGL_NONE
+ ]
+ configs = (ctypes.c_int32 * 1)()
+ num_configs = ctypes.c_int32()
+ ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
+ assert ok
+ assert num_configs.value == 1
+ config = configs[0]
+
+ # Create dummy pbuffer surface.
+ surface_attribs = [
+ egl.EGL_WIDTH, 1,
+ egl.EGL_HEIGHT, 1,
+ egl.EGL_NONE
+ ]
+ surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
+ assert surface != egl.EGL_NO_SURFACE
+
+ # Setup GL context.
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
+ assert ok
+ context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
+ assert context != egl.EGL_NO_CONTEXT
+ ok = egl.eglMakeCurrent(display, surface, surface, context)
+ assert ok
+
+#----------------------------------------------------------------------------
+
+_texture_formats = {
+ ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
+ ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
+ ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
+ ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
+ ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
+ ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
+ ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
+ ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
+}
+
+def get_texture_format(dtype, channels):
+ return _texture_formats[(np.dtype(dtype).name, int(channels))]
+
+#----------------------------------------------------------------------------
+
+def prepare_texture_data(image):
+ image = np.asarray(image)
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis]
+ if image.dtype.name == 'float64':
+ image = image.astype('float32')
+ return image
+
+#----------------------------------------------------------------------------
+
+def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
+ align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
+ image = prepare_texture_data(image)
+ height, width, channels = image.shape
+ size = zoom * [width, height]
+ pos = pos - size * align
+ if rint:
+ pos = np.rint(pos)
+ fmt = get_texture_format(image.dtype, channels)
+
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glRasterPos2f(pos[0], pos[1])
+ gl.glPixelZoom(zoom[0], -zoom[1])
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+ gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
+ gl.glPopClientAttrib()
+ gl.glPopAttrib()
+
+#----------------------------------------------------------------------------
+
+def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ dtype = np.dtype(dtype)
+ fmt = get_texture_format(dtype, channels)
+ image = np.empty([height, width, channels], dtype=dtype)
+
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
+ gl.glPopClientAttrib()
+ return np.flipud(image)
+
+#----------------------------------------------------------------------------
+
+class Texture:
+ def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
+ self.gl_id = None
+ self.bilinear = bilinear
+ self.mipmap = mipmap
+
+ # Determine size and dtype.
+ if image is not None:
+ image = prepare_texture_data(image)
+ self.height, self.width, self.channels = image.shape
+ self.dtype = image.dtype
+ else:
+ assert width is not None and height is not None
+ self.width = width
+ self.height = height
+ self.channels = channels if channels is not None else 3
+ self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
+
+ # Validate size and dtype.
+ assert isinstance(self.width, int) and self.width >= 0
+ assert isinstance(self.height, int) and self.height >= 0
+ assert isinstance(self.channels, int) and self.channels >= 1
+ assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
+
+ # Create texture object.
+ self.gl_id = gl.glGenTextures(1)
+ with self.bind():
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
+ self.update(image)
+
+ def delete(self):
+ if self.gl_id is not None:
+ gl.glDeleteTextures([self.gl_id])
+ self.gl_id = None
+
+ def __del__(self):
+ try:
+ self.delete()
+ except:
+ pass
+
+ @contextlib.contextmanager
+ def bind(self):
+ prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
+ gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
+ yield
+ gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
+
+ def update(self, image):
+ if image is not None:
+ image = prepare_texture_data(image)
+ assert self.is_compatible(image=image)
+ with self.bind():
+ fmt = get_texture_format(self.dtype, self.channels)
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
+ if self.mipmap:
+ gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
+ gl.glPopClientAttrib()
+
+ def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
+ size = zoom * [self.width, self.height]
+ with self.bind():
+ gl.glPushAttrib(gl.GL_ENABLE_BIT)
+ gl.glEnable(gl.GL_TEXTURE_2D)
+ draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
+ gl.glPopAttrib()
+
+ def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
+ if image is not None:
+ if image.ndim != 3:
+ return False
+ ih, iw, ic = image.shape
+ if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
+ return False
+ if width is not None and self.width != width:
+ return False
+ if height is not None and self.height != height:
+ return False
+ if channels is not None and self.channels != channels:
+ return False
+ if dtype is not None and self.dtype != dtype:
+ return False
+ return True
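+
+ # Editorial usage sketch (assumes an active GL context, e.g. one created by
+ # gui_utils.glfw_window.GlfwWindow; the image arrays are placeholders):
+ #
+ # tex = Texture(image=np.zeros([64, 64, 3], dtype=np.uint8))
+ # tex.draw(pos=[0, 0], zoom=4) # draw at 4x magnification
+ # tex.update(new_image) # upload pixels of the same shape and dtype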
+
+#----------------------------------------------------------------------------
+
+class Framebuffer:
+ def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
+ self.texture = texture
+ self.gl_id = None
+ self.gl_color = None
+ self.gl_depth_stencil = None
+ self.msaa = msaa
+
+ # Determine size and dtype.
+ if texture is not None:
+ assert isinstance(self.texture, Texture)
+ self.width = texture.width
+ self.height = texture.height
+ self.channels = texture.channels
+ self.dtype = texture.dtype
+ else:
+ assert width is not None and height is not None
+ self.width = width
+ self.height = height
+ self.channels = channels if channels is not None else 4
+ self.dtype = np.dtype(dtype) if dtype is not None else np.float32
+
+ # Validate size and dtype.
+ assert isinstance(self.width, int) and self.width >= 0
+ assert isinstance(self.height, int) and self.height >= 0
+ assert isinstance(self.channels, int) and self.channels >= 1
+ assert width is None or width == self.width
+ assert height is None or height == self.height
+ assert channels is None or channels == self.channels
+ assert dtype is None or dtype == self.dtype
+
+ # Create framebuffer object.
+ self.gl_id = gl.glGenFramebuffers(1)
+ with self.bind():
+
+ # Setup color buffer.
+ if self.texture is not None:
+ assert self.msaa == 0
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
+ else:
+ fmt = get_texture_format(self.dtype, self.channels)
+ self.gl_color = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
+
+ # Setup depth/stencil buffer.
+ self.gl_depth_stencil = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
+
+ def delete(self):
+ if self.gl_id is not None:
+ gl.glDeleteFramebuffers([self.gl_id])
+ self.gl_id = None
+ if self.gl_color is not None:
+ gl.glDeleteRenderbuffers(1, [self.gl_color])
+ self.gl_color = None
+ if self.gl_depth_stencil is not None:
+ gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
+ self.gl_depth_stencil = None
+
+ def __del__(self):
+ try:
+ self.delete()
+ except:
+ pass
+
+ @contextlib.contextmanager
+ def bind(self):
+ prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
+ if self.width is not None and self.height is not None:
+ gl.glViewport(0, 0, self.width, self.height)
+ yield
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
+
+ def blit(self, dst=None):
+ assert dst is None or isinstance(dst, Framebuffer)
+ with self.bind():
+ gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id) # Framebuffer stores its GL handle in gl_id
+ gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
+
+#----------------------------------------------------------------------------
+
+def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
+ assert vertices.ndim == 2 and vertices.shape[1] == 2
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
+ color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
+ alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
+
+ gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
+ gl.glMatrixMode(gl.GL_MODELVIEW)
+ gl.glPushMatrix()
+
+ gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
+ gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
+ gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
+ gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
+ gl.glTranslate(pos[0], pos[1], 0)
+ gl.glScale(size[0], size[1], 1)
+ gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
+ gl.glDrawArrays(mode, 0, vertices.shape[0])
+
+ gl.glPopMatrix()
+ gl.glPopAttrib()
+ gl.glPopClientAttrib()
+
+#----------------------------------------------------------------------------
+
+def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
+ assert pos2 is None or size is None
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
+ size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
+ pos = pos - size * align
+ if rint:
+ pos = np.rint(pos)
+ rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
+ rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
+ if np.min(rounding) == 0:
+ rounding *= 0
+ vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
+
+@functools.lru_cache(maxsize=10000)
+def _setup_rect(rx, ry):
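+ # Trace the outline of a unit square whose corners are replaced by quarter
+ # arcs with radii (rx, ry); with rx = ry = 0 the arcs collapse to the four
+ # plain corners, i.e. an ordinary rectangle fan.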
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
+ return v.astype('float32')
+
+#----------------------------------------------------------------------------
+
+def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
+ hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
+ vertices = _setup_circle(float(hole))
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
+
+@functools.lru_cache(maxsize=10000)
+def _setup_circle(hole):
+ t = np.linspace(0, np.pi * 2, 128)
+ s = np.sin(t); c = np.cos(t)
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
+ return v.astype('float32')
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/gui_utils/glfw_window.py b/ThirdParty/eg3d/gui_utils/glfw_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb96e8707db91c620825541c9b3c846b7362407
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/glfw_window.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import time
+import glfw
+import OpenGL.GL as gl
+from . import gl_utils
+
+#----------------------------------------------------------------------------
+
+class GlfwWindow: # pylint: disable=too-many-public-methods
+ def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
+ self._glfw_window = None
+ self._drawing_frame = False
+ self._frame_start_time = None
+ self._frame_delta = 0
+ self._fps_limit = None
+ self._vsync = None
+ self._skip_frames = 0
+ self._deferred_show = deferred_show
+ self._close_on_esc = close_on_esc
+ self._esc_pressed = False
+ self._drag_and_drop_paths = None
+ self._capture_next_frame = False
+ self._captured_frame = None
+
+ # Create window.
+ glfw.init()
+ glfw.window_hint(glfw.VISIBLE, False)
+ self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
+ self._attach_glfw_callbacks()
+ self.make_context_current()
+
+ # Adjust window.
+ self.set_vsync(False)
+ self.set_window_size(window_width, window_height)
+ if not self._deferred_show:
+ glfw.show_window(self._glfw_window)
+
+ def close(self):
+ if self._drawing_frame:
+ self.end_frame()
+ if self._glfw_window is not None:
+ glfw.destroy_window(self._glfw_window)
+ self._glfw_window = None
+ #glfw.terminate() # Commented out to play it nice with other glfw clients.
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ @property
+ def window_width(self):
+ return self.content_width
+
+ @property
+ def window_height(self):
+ return self.content_height + self.title_bar_height
+
+ @property
+ def content_width(self):
+ width, _height = glfw.get_window_size(self._glfw_window)
+ return width
+
+ @property
+ def content_height(self):
+ _width, height = glfw.get_window_size(self._glfw_window)
+ return height
+
+ @property
+ def title_bar_height(self):
+ _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
+ return top
+
+ @property
+ def monitor_width(self):
+ _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
+ return width
+
+ @property
+ def monitor_height(self):
+ _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
+ return height
+
+ @property
+ def frame_delta(self):
+ return self._frame_delta
+
+ def set_title(self, title):
+ glfw.set_window_title(self._glfw_window, title)
+
+ def set_window_size(self, width, height):
+ width = min(width, self.monitor_width)
+ height = min(height, self.monitor_height)
+ glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
+ if width == self.monitor_width and height == self.monitor_height:
+ self.maximize()
+
+ def set_content_size(self, width, height):
+ self.set_window_size(width, height + self.title_bar_height)
+
+ def maximize(self):
+ glfw.maximize_window(self._glfw_window)
+
+ def set_position(self, x, y):
+ glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
+
+ def center(self):
+ self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
+
+ def set_vsync(self, vsync):
+ vsync = bool(vsync)
+ if vsync != self._vsync:
+ glfw.swap_interval(1 if vsync else 0)
+ self._vsync = vsync
+
+ def set_fps_limit(self, fps_limit):
+ self._fps_limit = int(fps_limit)
+
+ def should_close(self):
+ return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
+
+ def skip_frame(self):
+ self.skip_frames(1)
+
+ def skip_frames(self, num): # Do not update window for the next N frames.
+ self._skip_frames = max(self._skip_frames, int(num))
+
+ def is_skipping_frames(self):
+ return self._skip_frames > 0
+
+ def capture_next_frame(self):
+ self._capture_next_frame = True
+
+ def pop_captured_frame(self):
+ frame = self._captured_frame
+ self._captured_frame = None
+ return frame
+
+ def pop_drag_and_drop_paths(self):
+ paths = self._drag_and_drop_paths
+ self._drag_and_drop_paths = None
+ return paths
+
+ def draw_frame(self): # To be overridden by subclass.
+ self.begin_frame()
+ # Rendering code goes here.
+ self.end_frame()
+
+ def make_context_current(self):
+ if self._glfw_window is not None:
+ glfw.make_context_current(self._glfw_window)
+
+ def begin_frame(self):
+ # End previous frame.
+ if self._drawing_frame:
+ self.end_frame()
+
+ # Apply FPS limit.
+ if self._frame_start_time is not None and self._fps_limit is not None:
+ delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
+ if delay > 0:
+ time.sleep(delay)
+ cur_time = time.perf_counter()
+ if self._frame_start_time is not None:
+ self._frame_delta = cur_time - self._frame_start_time
+ self._frame_start_time = cur_time
+
+ # Process events.
+ glfw.poll_events()
+
+ # Begin frame.
+ self._drawing_frame = True
+ self.make_context_current()
+
+ # Initialize GL state.
+ gl.glViewport(0, 0, self.content_width, self.content_height)
+ gl.glMatrixMode(gl.GL_PROJECTION)
+ gl.glLoadIdentity()
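+ # Map window pixel coordinates to clip space: (0, 0) becomes the top-left
+ # corner and (content_width, content_height) the bottom-right, matching the
+ # 2D drawing helpers in gl_utils.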
+ gl.glTranslate(-1, 1, 0)
+ gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
+ gl.glMatrixMode(gl.GL_MODELVIEW)
+ gl.glLoadIdentity()
+ gl.glEnable(gl.GL_BLEND)
+ gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
+
+ # Clear.
+ gl.glClearColor(0, 0, 0, 1)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
+
+ def end_frame(self):
+ assert self._drawing_frame
+ self._drawing_frame = False
+
+ # Skip frames if requested.
+ if self._skip_frames > 0:
+ self._skip_frames -= 1
+ return
+
+ # Capture frame if requested.
+ if self._capture_next_frame:
+ self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
+ self._capture_next_frame = False
+
+ # Update window.
+ if self._deferred_show:
+ glfw.show_window(self._glfw_window)
+ self._deferred_show = False
+ glfw.swap_buffers(self._glfw_window)
+
+ def _attach_glfw_callbacks(self):
+ glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
+ glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
+
+ def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
+ if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
+ self._esc_pressed = True
+
+ def _glfw_drop_callback(self, _window, paths):
+ self._drag_and_drop_paths = paths
+
+#----------------------------------------------------------------------------
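+
+# Editorial usage sketch (not part of the upstream file): subclasses override
+# draw_frame() as noted above and drive it from a simple loop; MyWindow is a
+# placeholder name.
+#
+#   class MyWindow(GlfwWindow):
+#       def draw_frame(self):
+#           self.begin_frame()
+#           # ... issue gl_utils draw calls here ...
+#           self.end_frame()
+#
+#   win = MyWindow(title='Demo', window_width=800, window_height=600)
+#   while not win.should_close():
+#       win.draw_frame()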
diff --git a/ThirdParty/eg3d/gui_utils/imgui_utils.py b/ThirdParty/eg3d/gui_utils/imgui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a8357caf20493956769984f32776441beefd27
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/imgui_utils.py
@@ -0,0 +1,171 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import contextlib
+import imgui
+
+#----------------------------------------------------------------------------
+
+def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
+ s = imgui.get_style()
+ s.window_padding = [spacing, spacing]
+ s.item_spacing = [spacing, spacing]
+ s.item_inner_spacing = [spacing, spacing]
+ s.columns_min_spacing = spacing
+ s.indent_spacing = indent
+ s.scrollbar_size = scrollbar
+ s.frame_padding = [4, 3]
+ s.window_border_size = 1
+ s.child_border_size = 1
+ s.popup_border_size = 1
+ s.frame_border_size = 1
+ s.window_rounding = 0
+ s.child_rounding = 0
+ s.popup_rounding = 3
+ s.frame_rounding = 3
+ s.scrollbar_rounding = 3
+ s.grab_rounding = 3
+
+ getattr(imgui, f'style_colors_{color_scheme}')(s)
+ c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
+ c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
+ s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
+
+#----------------------------------------------------------------------------
+
+@contextlib.contextmanager
+def grayed_out(cond=True):
+ if cond:
+ s = imgui.get_style()
+ text = s.colors[imgui.COLOR_TEXT_DISABLED]
+ grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
+ back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
+ imgui.push_style_color(imgui.COLOR_TEXT, *text)
+ imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
+ yield
+ imgui.pop_style_color(14)
+ else:
+ yield
+
+#----------------------------------------------------------------------------
+
+@contextlib.contextmanager
+def item_width(width=None):
+ if width is not None:
+ imgui.push_item_width(width)
+ yield
+ imgui.pop_item_width()
+ else:
+ yield
+
+#----------------------------------------------------------------------------
+
+def scoped_by_object_id(method):
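+ # Wrap a GUI method so its widgets get ImGui IDs scoped to the instance
+ # (push_id(id(self))), avoiding collisions when several objects draw
+ # identically-labelled controls.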
+ def decorator(self, *args, **kwargs):
+ imgui.push_id(str(id(self)))
+ res = method(self, *args, **kwargs)
+ imgui.pop_id()
+ return res
+ return decorator
+
+#----------------------------------------------------------------------------
+
+def button(label, width=0, enabled=True):
+ with grayed_out(not enabled):
+ clicked = imgui.button(label, width=width)
+ clicked = clicked and enabled
+ return clicked
+
+#----------------------------------------------------------------------------
+
+def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
+ expanded = False
+ if show:
+ if default:
+ flags |= imgui.TREE_NODE_DEFAULT_OPEN
+ if not enabled:
+ flags |= imgui.TREE_NODE_LEAF
+ with grayed_out(not enabled):
+ expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
+ expanded = expanded and enabled
+ return expanded, visible
+
+#----------------------------------------------------------------------------
+
+def popup_button(label, width=0, enabled=True):
+ if button(label, width, enabled):
+ imgui.open_popup(label)
+ opened = imgui.begin_popup(label)
+ return opened
+
+#----------------------------------------------------------------------------
+
+def input_text(label, value, buffer_length, flags, width=None, help_text=''):
+ old_value = value
+ color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
+ if value == '':
+ color[-1] *= 0.5
+ with item_width(width):
+ imgui.push_style_color(imgui.COLOR_TEXT, *color)
+ value = value if value != '' else help_text
+ changed, value = imgui.input_text(label, value, buffer_length, flags)
+ value = value if value != help_text else ''
+ imgui.pop_style_color(1)
+ if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
+ changed = (value != old_value)
+ return changed, value
+
+#----------------------------------------------------------------------------
+
+def drag_previous_control(enabled=True):
+ dragging = False
+ dx = 0
+ dy = 0
+ if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
+ if enabled:
+ dragging = True
+ dx, dy = imgui.get_mouse_drag_delta()
+ imgui.reset_mouse_drag_delta()
+ imgui.end_drag_drop_source()
+ return dragging, dx, dy
+
+#----------------------------------------------------------------------------
+
+def drag_button(label, width=0, enabled=True):
+ clicked = button(label, width=width, enabled=enabled)
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
+ return clicked, dragging, dx, dy
+
+#----------------------------------------------------------------------------
+
+def drag_hidden_window(label, x, y, width, height, enabled=True):
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
+ imgui.set_next_window_position(x, y)
+ imgui.set_next_window_size(width, height)
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
+ imgui.end()
+ imgui.pop_style_color(2)
+ return dragging, dx, dy
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/gui_utils/imgui_window.py b/ThirdParty/eg3d/gui_utils/imgui_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1a6382b41c593c5ea4d9d2888c716282e575ec
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/imgui_window.py
@@ -0,0 +1,105 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import imgui
+import imgui.integrations.glfw
+
+from . import glfw_window
+from . import imgui_utils
+from . import text_utils
+
+#----------------------------------------------------------------------------
+
+class ImguiWindow(glfw_window.GlfwWindow):
+ def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
+ if font is None:
+ font = text_utils.get_default_font()
+ font_sizes = {int(size) for size in font_sizes}
+ super().__init__(title=title, **glfw_kwargs)
+
+ # Init fields.
+ self._imgui_context = None
+ self._imgui_renderer = None
+ self._imgui_fonts = None
+ self._cur_font_size = max(font_sizes)
+
+ # Delete leftover imgui.ini to avoid unexpected behavior.
+ if os.path.isfile('imgui.ini'):
+ os.remove('imgui.ini')
+
+ # Init ImGui.
+ self._imgui_context = imgui.create_context()
+ self._imgui_renderer = _GlfwRenderer(self._glfw_window)
+ self._attach_glfw_callbacks()
+ imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
+ imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
+ self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
+ self._imgui_renderer.refresh_font_texture()
+
+ def close(self):
+ self.make_context_current()
+ self._imgui_fonts = None
+ if self._imgui_renderer is not None:
+ self._imgui_renderer.shutdown()
+ self._imgui_renderer = None
+ if self._imgui_context is not None:
+ #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
+ self._imgui_context = None
+ super().close()
+
+ def _glfw_key_callback(self, *args):
+ super()._glfw_key_callback(*args)
+ self._imgui_renderer.keyboard_callback(*args)
+
+ @property
+ def font_size(self):
+ return self._cur_font_size
+
+ @property
+ def spacing(self):
+ return round(self._cur_font_size * 0.4)
+
+ def set_font_size(self, target): # Applied on next frame.
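+ # Snap to the closest size registered in self._imgui_fonts; fonts are baked
+ # per size at startup, so arbitrary sizes are not available.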
+ self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
+
+ def begin_frame(self):
+ # Begin glfw frame.
+ super().begin_frame()
+
+ # Process imgui events.
+ self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
+ if self.content_width > 0 and self.content_height > 0:
+ self._imgui_renderer.process_inputs()
+
+ # Begin imgui frame.
+ imgui.new_frame()
+ imgui.push_font(self._imgui_fonts[self._cur_font_size])
+ imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
+
+ def end_frame(self):
+ imgui.pop_font()
+ imgui.render()
+ imgui.end_frame()
+ self._imgui_renderer.render(imgui.get_draw_data())
+ super().end_frame()
+
+#----------------------------------------------------------------------------
+# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
+
+class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.mouse_wheel_multiplier = 1
+
+ def scroll_callback(self, window, x_offset, y_offset):
+ self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
+
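+# Minimal usage sketch (hypothetical app code; assumes the GlfwWindow base
+# class accepts window_width/window_height and exposes should_close(), as in
+# the accompanying glfw_window.py):
+#
+#   win = ImguiWindow(title='Demo', window_width=800, window_height=600)
+#   while not win.should_close():
+#       win.begin_frame()
+#       imgui.text('hello')
+#       win.end_frame()
+#   win.close()
+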
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/gui_utils/text_utils.py b/ThirdParty/eg3d/gui_utils/text_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64a34d1287d58960141fa06a8e76446cd9cebc8
--- /dev/null
+++ b/ThirdParty/eg3d/gui_utils/text_utils.py
@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import functools
+from typing import Optional
+
+import dnnlib
+import numpy as np
+import PIL.Image
+import PIL.ImageFont
+import scipy.ndimage
+
+from . import gl_utils
+
+#----------------------------------------------------------------------------
+
+def get_default_font():
+ url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
+ return dnnlib.util.open_url(url, return_filename=True)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=None)
+def get_pil_font(font=None, size=32):
+ if font is None:
+ font = get_default_font()
+ return PIL.ImageFont.truetype(font=font, size=size)
+
+#----------------------------------------------------------------------------
+
+def get_array(string, *, dropshadow_radius: int=None, **kwargs):
+ if dropshadow_radius is not None:
+ offset_x = int(np.ceil(dropshadow_radius*2/3))
+ offset_y = int(np.ceil(dropshadow_radius*2/3))
+ return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
+ else:
+ return _get_array_priv(string, **kwargs)
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_priv(
+ string: str, *,
+ size: int = 32,
+ max_width: Optional[int]=None,
+ max_height: Optional[int]=None,
+ min_size=10,
+ shrink_coef=0.8,
+ dropshadow_radius: int=None,
+ offset_x: int=None,
+ offset_y: int=None,
+ **kwargs
+):
+ cur_size = size
+ array = None
+ while True:
+ if dropshadow_radius is not None:
+ # separate implementation for dropshadow text rendering
+ array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
+ else:
+ array = _get_array_impl(string, size=cur_size, **kwargs)
+ height, width, _ = array.shape
+ if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
+ break
+ cur_size = max(int(cur_size * shrink_coef), min_size)
+ return array
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
+ pil_font = get_pil_font(font=font, size=size)
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
+ width = max(line.shape[1] for line in lines)
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
+ line_spacing = line_pad if line_pad is not None else size // 2
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
+ mask = np.concatenate(lines, axis=0)
+ alpha = mask
+ if outline > 0:
+ mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
+ alpha = mask.astype(np.float32) / 255
+ alpha = scipy.ndimage.gaussian_filter(alpha, outline)
+ alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
+ alpha = np.maximum(alpha, mask)
+ return np.stack([mask, alpha], axis=-1)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
+ assert (offset_x > 0) and (offset_y > 0)
+ pil_font = get_pil_font(font=font, size=size)
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
+ width = max(line.shape[1] for line in lines)
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
+ line_spacing = line_pad if line_pad is not None else size // 2
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
+ mask = np.concatenate(lines, axis=0)
+ alpha = mask
+
+ mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
+ alpha = mask.astype(np.float32) / 255
+ alpha = scipy.ndimage.gaussian_filter(alpha, radius)
+ alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
+ alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
+ alpha = np.maximum(alpha, mask)
+ return np.stack([mask, alpha], axis=-1)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def get_texture(string, bilinear=True, mipmap=True, **kwargs):
+ return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
+
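+# Example (hypothetical text/size): get_array('Hello', size=24, outline=2)
+# returns a uint8 array of shape [H, W, 2] holding the glyph mask and a
+# soft-outlined alpha channel; get_texture() wraps the same array in a
+# gl_utils.Texture for on-screen rendering.
+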
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/legacy.py b/ThirdParty/eg3d/legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30944a15c8f7da114c3b1d94da8c31b1ed13ae8
--- /dev/null
+++ b/ThirdParty/eg3d/legacy.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Converting legacy network pickle into the new format."""
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['G'], torch.nn.Module)
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+ fp16_kwargs.num_fp16_res = 4
+ fp16_kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ from training import networks_stylegan2
+ network_class = networks_stylegan2.Generator
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ kwarg('conditioning')
+ kwarg('fused_modconv')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+            kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ G = network_class(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ r'.*\.act_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ kwarg('conditioning')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks_stylegan2
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/__init__.py b/ThirdParty/eg3d/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/ThirdParty/eg3d/metrics/equivariance.py b/ThirdParty/eg3d/metrics/equivariance.py
new file mode 100644
index 0000000000000000000000000000000000000000..4609296593dd60cf0a1afa28ae4abb17d5b23576
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/equivariance.py
@@ -0,0 +1,269 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from torch_utils.ops import upfirdn2d
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+ y = (x * np.pi).abs()
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+ x = x.abs() / a
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+ angle = torch.as_tensor(angle).to(torch.float32)
+ mat = torch.eye(3, device=angle.device)
+ mat[0, 0] = angle.cos()
+ mat[0, 1] = angle.sin()
+ mat[1, 0] = -angle.sin()
+ mat[1, 1] = angle.cos()
+ return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
+
+def apply_integer_translation(x, tx, ty):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.round().to(torch.int64)
+ iy = ty.round().to(torch.int64)
+
+ z = torch.zeros_like(x)
+ m = torch.zeros_like(x)
+ if abs(ix) < W and abs(iy) < H:
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
+
+def apply_fractional_translation(x, tx, ty, a=3):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.floor().to(torch.int64)
+ iy = ty.floor().to(torch.int64)
+ fx = tx - ix
+ fy = ty - iy
+ b = a - 1
+
+ z = torch.zeros_like(x)
+ zx0 = max(ix - b, 0)
+ zy0 = max(iy - b, 0)
+ zx1 = min(ix + a, 0) + W
+ zy1 = min(iy + a, 0) + H
+ if zx0 < zx1 and zy0 < zy1:
+ taps = torch.arange(a * 2, device=x.device) - b
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
+ y = x
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
+ z[:, :, zy0:zy1, zx0:zx1] = y
+
+ m = torch.zeros_like(x)
+ mx0 = max(ix + a, 0)
+ my0 = max(iy + a, 0)
+ mx1 = min(ix - b, 0) + W
+ my1 = min(iy - b, 0) + H
+ if mx0 < mx1 and my0 < my1:
+ m[:, :, my0:my1, mx0:mx1] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Construct an oriented low-pass filter that applies the appropriate
+# bandlimit with respect to the input and output of the given affine 2D
+# image transformation.
+
+def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
+ assert a <= amax < aflt
+ mat = torch.as_tensor(mat).to(torch.float32)
+
+ # Construct 2D filter taps in input & output coordinate spaces.
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
+ yi, xi = torch.meshgrid(taps, taps)
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
+
+ # Convolution of two oriented 2D sinc filters.
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
+
+ # Convolution of two oriented 2D Lanczos windows.
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
+
+ # Construct windowed FIR filter.
+ f = f * w
+
+ # Finalize.
+ c = (aflt - amax) * up
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
+ return f
+
+#----------------------------------------------------------------------------
+# Apply the given affine transformation to a batch of 2D images.
+
+def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
+ _N, _C, H, W = x.shape
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
+
+ # Construct filter.
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
+ p = f.shape[0] // 2
+
+ # Construct sampling grid.
+ theta = mat.inverse()
+ theta[:2, 2] *= 2
+ theta[0, 2] += 1 / up / W
+ theta[1, 2] += 1 / up / H
+ theta[0, :] *= W / (W + p / up * 2)
+ theta[1, :] *= H / (H + p / up * 2)
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
+
+ # Resample image.
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+ # Form mask.
+ m = torch.zeros_like(y)
+ c = p * 2 + 1
+ m[:, :, c:-c, c:-c] = 1
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(angle)
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergone
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(-angle)
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+ y = upfirdn2d.filter2d(x=x, f=f)
+ m = torch.zeros_like(y)
+ c = f.shape[0] // 2
+ m[:, :, c:-c, c:-c] = 1
+ return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ I = torch.eye(3, device=opts.device)
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+ if M is None:
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ sums = None
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ s = []
+
+ # Randomize noise buffers, if any.
+ for name, buf in G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Run mapping network.
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
+ c = next(c_iter)
+ ws = G.mapping(z=z, c=c)
+
+ # Generate reference image.
+ M[:] = I
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+ # Integer translation (EQ-T).
+ if compute_eqt_int:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ t = (t * G.img_resolution).round() / G.img_resolution
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Fractional translation (EQ-T_frac).
+ if compute_eqt_frac:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Rotation (EQ-R).
+ if compute_eqr:
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+ M[:] = rotation_matrix(-angle)
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+ mask = ref_mask * pseudo_mask
+ s += [(ref - pseudo).square() * mask, mask]
+
+ # Accumulate results.
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
+ sums = sums + s if sums is not None else s
+ progress.update(num_samples)
+
+ # Compute PSNRs.
+ if opts.num_gpus > 1:
+ torch.distributed.all_reduce(sums)
+ sums = sums.cpu()
+ mses = sums[0::2] / sums[1::2]
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
+ psnrs = tuple(psnrs.numpy())
+ return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/frechet_inception_distance.py b/ThirdParty/eg3d/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2944eb21dbb88d2f383991ff88f557513b38168
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/frechet_inception_distance.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
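+
+# For reference, compute_fid() below evaluates the closed-form Frechet distance
+# between Gaussians fitted to the real and generated feature sets:
+#   FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 * (Sigma_r Sigma_g)^(1/2))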
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/inception_score.py b/ThirdParty/eg3d/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e5e247280f76471819550295bf2fc5ea3f7b42e
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/inception_score.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
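+
+# For reference, compute_is() below estimates
+#   IS = exp( E_x[ KL( p(y|x) || p(y) ) ] )
+# over num_splits consecutive splits of the generated samples and returns the
+# mean and standard deviation of the per-split scores.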
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/kernel_inception_distance.py b/ThirdParty/eg3d/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..48906eba23a7d29ba912b7d209f83fba6d0b9f37
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/kernel_inception_distance.py
@@ -0,0 +1,48 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
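+
+# For reference, compute_kid() below estimates the squared MMD between real and
+# generated features under the polynomial kernel k(x, y) = (x . y / d + 1)^3,
+# where d is the feature dimension, averaged over num_subsets random subsets of
+# at most max_subset_size samples each.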
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/metric_main.py b/ThirdParty/eg3d/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..52318ee48a523f30e7eace0b62b936c7826ffc56
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/metric_main.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
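+# Illustrative call sequence (hypothetical variables G_ema, training_set_kwargs,
+# run_dir and snapshot_pkl, as they would appear in a training loop):
+#
+#   result = calc_metric('fid50k_full', G=G_ema, dataset_kwargs=training_set_kwargs,
+#                        num_gpus=1, rank=0, device=torch.device('cuda'))
+#   report_metric(result, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
+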
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+ return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+ return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+ return dict(eqr50k=psnr)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/metric_utils.py b/ThirdParty/eg3d/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..212cb7d38fabf6c7b60c55a0fa0a07560ac602b2
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/metric_utils.py
@@ -0,0 +1,281 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+ if opts.G.c_dim == 0:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ while True:
+ yield c
+ else:
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ while True:
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
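+# Note: get_mean_cov() returns the biased (1/N) estimates
+#   mean = (1/N) * sum_i x_i,   cov = (1/N) * sum_i x_i x_i^T - mean mean^T,
+# accumulated in float64; frechet_inception_distance.compute_fid() consumes
+# these via capture_mean_cov=True.
+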
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ images.append(img)
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
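+
+#----------------------------------------------------------------------------
+# Illustrative sketch (not part of the upstream EG3D code): how the two helpers
+# above can be combined into an FID-style score. It assumes the FeatureStats
+# class added earlier in this file supports capture_mean_cov / get_mean_cov,
+# that numpy is imported as np at the top of this module, and that scipy is
+# available in the environment.
+
+def _example_fid_sketch(opts, detector_url, detector_kwargs, max_real, num_gen):
+    import scipy.linalg  # local import; the rest of this module does not need scipy
+    # Accumulate feature means/covariances for real and generated images.
+    mu_real, sigma_real = compute_feature_stats_for_dataset(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+    mu_gen, sigma_gen = compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+    # Frechet distance between the two Gaussian feature distributions.
+    m = np.square(mu_gen - mu_real).sum()
+    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
+    return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))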
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/perceptual_path_length.py b/ThirdParty/eg3d/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e58dac3317733e2ace6d64ee1f97cafa0a38225
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/perceptual_path_length.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
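+
+# Illustrative sanity check (not part of the upstream code): slerp returns the
+# normalized endpoint at t=0 and stays on the unit sphere for intermediate t.
+def _example_slerp_check():
+    a = torch.randn(4, 512)
+    b = torch.randn(4, 512)
+    mid = slerp(a, b, torch.full([4, 1], 0.5))
+    assert torch.allclose(mid.norm(dim=-1), torch.ones(4), atol=1e-4)
+    assert torch.allclose(slerp(a, b, torch.zeros(4, 1)), a / a.norm(dim=-1, keepdim=True), atol=1e-4)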
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler and labels.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ x = sampler(next(c_iter))
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
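+
+# Illustrative usage sketch (not part of the upstream code), assuming the
+# MetricOptions helper defined in metric_utils.py earlier in this diff. The
+# settings below follow a typical "W space / endpoint sampling" configuration.
+def _example_compute_ppl(G, dataset_kwargs, device='cuda'):
+    opts = metric_utils.MetricOptions(G=G, dataset_kwargs=dataset_kwargs,
+                                      num_gpus=1, rank=0, device=device)
+    return compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w',
+                       sampling='end', crop=False, batch_size=2)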
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/metrics/precision_recall.py b/ThirdParty/eg3d/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..e33e85f64de81fa211135edaf3863c2fe851a6f4
--- /dev/null
+++ b/ThirdParty/eg3d/metrics/precision_recall.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
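+
+# Illustrative single-process sanity check (not part of the upstream code):
+# with num_gpus=1 and rank=0, compute_distances should match torch.cdist.
+def _example_compute_distances_check():
+    rows = torch.randn(8, 16)
+    cols = torch.randn(10, 16)
+    dist = compute_distances(row_features=rows, col_features=cols,
+                             num_gpus=1, rank=0, col_batch_size=4)
+    assert torch.allclose(dist, torch.cdist(rows.unsqueeze(0), cols.unsqueeze(0))[0], atol=1e-5)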
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/shape_utils.py b/ThirdParty/eg3d/shape_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16f6cc82a59d9d3e455ba334abf68b576fdc10f
--- /dev/null
+++ b/ThirdParty/eg3d/shape_utils.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+
+"""
+Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.)
+
+Takes as input an .mrc file and extracts a mesh.
+
+Ex.
+ python shape_utils.py my_shape.mrc
+Ex.
+ python shape_utils.py myshapes_directory --level=12
+"""
+
+
+import time
+import plyfile
+import glob
+import logging
+import numpy as np
+import os
+import random
+import torch
+import torch.utils.data
+import trimesh
+import skimage.measure
+import argparse
+import mrcfile
+from tqdm import tqdm
+
+
+def convert_sdf_samples_to_ply(
+ numpy_3d_sdf_tensor,
+ voxel_grid_origin,
+ voxel_size,
+ ply_filename_out,
+ offset=None,
+ scale=None,
+ level=0.0
+):
+ """
+ Convert sdf samples to .ply
+    :param numpy_3d_sdf_tensor: a numpy float array of shape (n, n, n)
+ :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
+ :voxel_size: float, the size of the voxels
+ :ply_filename_out: string, path of the filename to save to
+ This function adapted from: https://github.com/RobotLocomotion/spartan
+ """
+ start_time = time.time()
+
+ verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
+ # try:
+ verts, faces, normals, values = skimage.measure.marching_cubes(
+ numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
+ )
+ # except:
+ # pass
+
+ # transform from voxel coordinates to camera coordinates
+ # note x and y are flipped in the output of marching_cubes
+ mesh_points = np.zeros_like(verts)
+ mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
+ mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
+ mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
+
+ # apply additional offset and scale
+ if scale is not None:
+ mesh_points = mesh_points / scale
+ if offset is not None:
+ mesh_points = mesh_points - offset
+
+ # try writing to the ply file
+
+ num_verts = verts.shape[0]
+ num_faces = faces.shape[0]
+
+ verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
+
+ for i in range(0, num_verts):
+ verts_tuple[i] = tuple(mesh_points[i, :])
+
+ faces_building = []
+ for i in range(0, num_faces):
+ faces_building.append(((faces[i, :].tolist(),)))
+ faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
+
+ el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
+ el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
+
+ ply_data = plyfile.PlyData([el_verts, el_faces])
+ ply_data.write(ply_filename_out)
+ print(f"wrote to {ply_filename_out}")
+
+
+def convert_mrc(input_filename, output_filename, isosurface_level=1):
+ with mrcfile.open(input_filename) as mrc:
+ convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)
+
+if __name__ == '__main__':
+ start_time = time.time()
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_mrc_path')
+ parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
+ args = parser.parse_args()
+
+    if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc':
+        output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply'
+        convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level)
+
+ print(f"{time.time() - start_time:02f} s")
+ else:
+ assert os.path.isdir(args.input_mrc_path)
+
+ for mrc_path in tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))):
+ output_obj_path = mrc_path.split('.mrc')[0] + '.ply'
+ convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level)
\ No newline at end of file
diff --git a/ThirdParty/eg3d/torch_utils/__init__.py b/ThirdParty/eg3d/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/ThirdParty/eg3d/torch_utils/__pycache__/__init__.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..661f7dde7f0e738919bc34eacc21c8ebbd342c11
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/__pycache__/custom_ops.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/__pycache__/custom_ops.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ff9cedc2f46d31d5d333222afd19b843641fb83
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/__pycache__/custom_ops.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/__pycache__/misc.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b6ea02a1dcb1d5f87ce09beb9f33c044c339d87
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/__pycache__/misc.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/__pycache__/persistence.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/__pycache__/persistence.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c5af3cb936b8a5f48c947cc545279e082235e38
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/__pycache__/persistence.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/custom_ops.py b/ThirdParty/eg3d/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2524f47ab3d5b8750cfb868cc14012f424acc8
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/custom_ops.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/misc.py b/ThirdParty/eg3d/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f15d37235fcf5458b27302c278209754bc83965
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/misc.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+from ThirdParty.eg3d import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
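+
+# Illustrative check (not part of the upstream code): repeated calls with identical
+# arguments return the exact same cached tensor object, so no new copy is created.
+def _example_constant_cache():
+    a = constant([1.0, 2.0, 3.0])
+    b = constant([1.0, 2.0, 3.0])
+    assert a is b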
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
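+
+# Illustrative usage (not part of the upstream code): None acts as a wildcard,
+# so any batch size is accepted while the remaining dimensions are checked.
+def _example_assert_shape():
+    x = torch.zeros(8, 3, 64, 64)
+    assert_shape(x, [None, 3, 64, 64])    # passes
+    # assert_shape(x, [None, 1, 64, 64])  # would raise AssertionError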
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
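+
+# Illustrative usage (not part of the upstream code): because the sampler never
+# terminates, the DataLoader is typically consumed with iter()/next() or islice.
+def _example_infinite_sampler(dataset, batch_size=4):
+    import itertools
+    sampler = InfiniteSampler(dataset=dataset, rank=0, num_replicas=1, seed=0)
+    loader = torch.utils.data.DataLoader(dataset=dataset, sampler=sampler, batch_size=batch_size)
+    return list(itertools.islice(iter(loader), 3))  # first three batches; the loader itself never ends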
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
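+
+# Illustrative usage (not part of the upstream code): summarize a small network
+# by running one dummy forward pass through it.
+def _example_print_module_summary():
+    net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
+    return print_module_summary(net, [torch.zeros(1, 3, 32, 32)])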
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/__init__.py b/ThirdParty/eg3d/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/__init__.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cebdd06b975c8a96eb7ac2b5d011646686d6e2f9
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee5016e66dc88b0e44f4effd84501c008974783a
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..180dd45663d68ff402594da070f529257e5e0661
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_resample.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_resample.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44dd2d08423f4c022cc4a4dc30d739e73cec45ab
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/conv2d_resample.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/fma.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/fma.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b02e77fcc850ee3bfbf9c2d74b02144eff7fecb
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/fma.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc b/ThirdParty/eg3d/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e534a4f1f89f726d86099adda068ba10b6c955ca
Binary files /dev/null and b/ThirdParty/eg3d/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc differ
diff --git a/ThirdParty/eg3d/torch_utils/ops/bias_act.cpp b/ThirdParty/eg3d/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f6d0caaf4f84b94851d223e384344e1109cdc
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,103 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/bias_act.cu b/ThirdParty/eg3d/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..71ca3900deda41e62d80044f0e409875f4c794b5
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/bias_act.cu
@@ -0,0 +1,177 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>     { typedef double scalar_t; };
+template <> struct InternalType<float>      { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+    return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/bias_act.h b/ThirdParty/eg3d/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..8994bfb4e9cae790865348e08de5f685152d3344
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/bias_act.h
@@ -0,0 +1,42 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/bias_act.py b/ThirdParty/eg3d/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46ca82fed202efe31b615698981c76d935f9e72
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/bias_act.py
@@ -0,0 +1,211 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+from ThirdParty.eg3d import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
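+
+# Illustrative usage (not part of the upstream code): fused bias + leaky ReLU on a
+# conv-style feature map. impl='ref' keeps the example runnable without the CUDA plugin.
+def _example_bias_act():
+    x = torch.randn(4, 8, 16, 16)  # NCHW activations
+    b = torch.zeros(8)             # one bias value per channel (dim=1)
+    return bias_act(x, b, dim=1, act='lrelu', gain=1, impl='ref')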
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
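+# Usage sketch (assumes the compiled CUDA plugin is available; shapes are illustrative):
+#   op = _bias_act_cuda(dim=1, act='lrelu')
+#   y = op.apply(x, b)   # matches the pure-PyTorch reference path up to numerical precision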
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/conv2d_gradfix.py b/ThirdParty/eg3d/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a177cc1c0b6eabf16908cf9afaa4387e7716b72
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,199 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
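+# Usage sketch (illustrative): skip the expensive weight gradients when only input
+# gradients are needed, e.g. during a regularization pass:
+#   with no_weight_gradients():
+#       (grad_img,) = torch.autograd.grad(outputs=[loss.sum()], inputs=[img], create_graph=True)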
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
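+# Illustrative usage (assumption: the caller toggles the module-level `enabled` flag):
+#   conv2d_gradfix.enabled = True
+#   y = conv2d_gradfix.conv2d(x, w, bias=b, stride=1, padding=1)
+# When the flag is off, or x is not a CUDA tensor, both wrappers fall back to the
+# corresponding torch.nn.functional call (see _should_use_custom_op below).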
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
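+ # The expression above inverts conv_transpose2d's output-size formula: it chooses the
+ # output_padding that makes the transposed convolution used for grad_input reproduce
+ # the original input's spatial size exactly.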
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, weight):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
+
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/conv2d_resample.py b/ThirdParty/eg3d/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..d46f4ddd85606b9032d08efe3556ecad4676cee5
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,145 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
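+ # (The filter taps enlarge the footprint of the resampling stage, so the requested
+ # padding is grown by roughly half the filter width on each side to compensate.)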
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
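+# Usage sketch (hypothetical shapes and filter, not from the original file):
+#   f = upfirdn2d.setup_filter([1, 3, 3, 1])   # separable low-pass filter
+#   y = conv2d_resample(x, w, f=f, up=2, padding=1, flip_weight=False)
+# With a 3x3 weight and this padding, y comes out at twice the spatial resolution of x.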
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cpp b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f55466235a020b0f5e150350bfdcd8b2a1e579d
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,304 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
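+ // Note: each sign element uses 2 bits (a "was negative" bit and a "was clamped" bit),
+ // so four elements are packed per uint8 and the allocated width is sw >> 2 bytes.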
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+ p.fu = fu.data_ptr<float>();
+ p.fd = fd.data_ptr<float>();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true>(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
+ else
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cu b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..aaac95408365f023ffaa4cb89348d499d3b948f0
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1288 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "filtered_lrelu.h"
+#include <cstdint>
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType<float>
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType<c10::Half>
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
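+// Illustrative note: for non-power-of-two N up to 256 the division is replaced by a
+// fixed-point multiply by (1<<24)/N + 1 and a 24-bit shift, which is intended to be exact
+// under the stated assumptions (N <= 256, i < N*256); other N fall through to plain i/N.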
+
+// Type cast stride before reading it.
+template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast<const T*>(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType<T>::scalar_t scalar_t;
+ typedef typename InternalType<T>::vec2_t vec2_t;
+ typedef typename InternalType<T>::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod<tileInW>(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType<T>::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType<T>::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType<T>::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType<T>::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ }
+
+ if (!downInline)
+ {
+ // Write into temporary buffer.
+ s_tileUpXY[dst] = v.x;
+ if (relUpY0 < tileUpH - 1)
+ s_tileUpXY[dst + tileUpW] = v.y;
+ }
+ else
+ {
+ // Write directly into output buffer.
+ if ((uint32_t)x < p.yShape.x)
+ {
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+ }
+ }
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
+ {
+ // Full upsampling filter.
+
+ if (up == 2)
+ {
+ // 2 x 2-wide.
+ __syncthreads();
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+ int src0 = relInX0 + tileInW * relInY0;
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+ #define X_LOOP(TAPY, PX) \
+ for (int sx = 0; sx < fuSize / up; sx++) \
+ { \
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ }
+
+ vec4_t v = InternalType<T>::zero_vec4();
+ if (tap0y == 0 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 0) }
+ if (tap0y == 0 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 1) }
+ if (tap0y == 1 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 0) }
+ if (tap0y == 1 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 1) }
+
+ #undef X_LOOP
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
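+ // Pack four 2-bit codes into one byte: bit 0 = negative (slope applied), bit 1 = clamped.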
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
+
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read sign and apply.
+ {
+ if ((uint32_t)signY < p.sShape.y)
+ {
+ int s = 0;
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+ s >>= (signX & 3) << 1;
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[idx + 0] = v.x;
+ s_tileUpXY[idx + 1] = v.y;
+ s_tileUpXY[idx + 2] = v.z;
+ s_tileUpXY[idx + 3] = v.w;
+ }
+ }
+ else if (up == 1)
+ {
+ __syncthreads();
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ v *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write sign.
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType<T>::clamp(v, p.clamp);
+ }
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ }
+ else
+ {
+ // Determine and write sign.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType<T>::clamp(v, p.clamp);
+ }
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ else
+ {
+ // Just compute the value.
+ if (v < 0.f) v *= p.slope;
+ v = InternalType<T>::clamp(v, p.clamp);
+ }
+ }
+ }
+ else if (signRead)
+ {
+ // Read sign and apply if within sign tensor bounds.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
+ {
+ int s = p.s[si];
+ s >>= signXo;
+ if (s & 1) v *= p.slope;
+ if (s & 2) v = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v < 0.f) v *= p.slope;
+ v = InternalType<T>::clamp(v, p.clamp);
+ }
+
+ if (!downInline) // Write into temporary buffer.
+ s_tileUpXY[idx] = v;
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+
+ // Downsampling.
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
+ {
+ // Horizontal downsampling.
+ __syncthreads();
+ if (down == 4 && tileOutW % 4 == 0)
+ {
+ // Calculate 4 pixels at a time.
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec4_t v = InternalType<T>::zero_vec4();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ s_tileDownX[idx+2] = v.z;
+ s_tileDownX[idx+3] = v.w;
+ }
+ }
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
+ {
+ // Calculate 2 pixels at a time.
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec2_t v = InternalType<T>::zero_vec2();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ }
+ }
+ else
+ {
+ // Calculate 1 pixel at a time.
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src = relUpY * tileUpW + relUpX0;
+ scalar_t v = 0.f;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+ s_tileDownX[idx] = v;
+ }
+ }
+
+ // Vertical downsampling & store output tile.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX, relOutY0;
+ fast_div_mod(relOutX, relOutY0, idx);
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileOutW + relOutX;
+ scalar_t v = 0;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY0;
+
+ if (outX < p.yShape.x & outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
+ {
+ // Full downsampling filter.
+ if (down == 2)
+ {
+ // 2-wide.
+ __syncthreads();
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ int relUpX0 = relOutX0 * down;
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileUpW + relUpX0;
+ vec2_t v = InternalType<T>::zero_vec2();
+ #pragma unroll
+ for (int sy = 0; sy < fdSize; sy++)
+ #pragma unroll
+ for (int sx = 0; sx < fdSize; sx++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ }
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outY < p.yShape.y)
+ {
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
+ }
+ }
+ }
+ else if (down == 1 && !downInline)
+ {
+ // Thread per pixel.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ }
+
+ if (!enableXrep)
+ break;
+ }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template <class T, bool signWrite, bool signRead>
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+ typedef typename InternalType<T>::scalar_t scalar_t;
+
+ // Indexing.
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+ // Loop to accommodate oversized tensors.
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+ {
+ // Extract z and w (channel, minibatch index).
+ int32_t w = q / p.xShape.z;
+ int32_t z = q - w * p.xShape.z;
+
+ // Choose behavior based on sign read/write mode.
+ if (signWrite)
+ {
+ // Process value if in p.x.
+ uint32_t s = 0;
+ if (x < p.xShape.x && y < p.xShape.y)
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+
+ // Gain, LReLU, clamp.
+ v *= p.gain;
+ if (v < 0.f)
+ {
+ v *= p.slope;
+ s = 1; // Sign.
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ v = InternalType<T>::clamp(v, p.clamp);
+ s = 2; // Clamp.
+ }
+
+ *pv = (T)v; // Write value.
+ }
+
+ // Coalesce into threads 0 and 16 of warp.
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
+ s |= __shfl_xor_sync(m, s, 2);
+ s |= __shfl_xor_sync(m, s, 4);
+ s |= __shfl_xor_sync(m, s, 8);
+
+ // Write signs if leader and in p.s.
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+ {
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+ ((uint32_t*)p.s)[is >> 4] = s;
+ }
+ }
+ else if (signRead)
+ {
+ // Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+
+ // Apply sign buffer offset.
+ uint32_t sx = x + p.sOfs.x;
+ uint32_t sy = y + p.sOfs.y;
+
+ // Read and apply signs if we land inside valid region of sign buffer.
+ if (sx < p.sShape.x && sy < p.sShape.y)
+ {
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+ unsigned char s = p.s[is];
+ s >>= (sx & 3) << 1; // Shift into place.
+ if (s & 1) // Sign?
+ v *= p.slope;
+ if (s & 2) // Clamp?
+ v = 0.f;
+ }
+
+ *pv = (T)v; // Write value.
+ }
+ }
+ else
+ {
+ // Forward pass with no sign write. Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+ if (v < 0.f)
+ v *= p.slope;
+ if (fabsf(v) > p.clamp)
+ v = InternalType<T>::clamp(v, p.clamp);
+ *pv = (T)v; // Write value.
+ }
+ }
+ }
+}
+
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
+{
+    return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+ filtered_lrelu_kernel_spec s = { 0 };
+
+ // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+ if (sharedKB >= SH) \
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+ { \
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+ s.setup = (void*)setup_filters_kernel; \
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
+ s.tileOut = make_int2(TW, TH); \
+ s.numWarps = W; \
+ s.xrep = XR; \
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+ return s; \
+ }
+
+ // Launch parameters for various kernel specializations.
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
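+ // For example, separable 8-tap up/down filters with up=2, down=2 on a 48 kB device fall through to the "4t-ups2-downs2" entry below (56x29 tile, 16 warps).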
+
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
+
+ #undef CASE
+ return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.h b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfd1dd537909de9cd3b14765a482056391683b
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.h
@@ -0,0 +1,94 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+ // These parameters decide which kernel to use.
+ int up; // upsampling ratio (1, 2, 4)
+ int down; // downsampling ratio (1, 2, 4)
+ int2 fuShape; // [size, 1] | [size, size]
+ int2 fdShape; // [size, 1] | [size, size]
+
+ int _dummy; // Alignment.
+
+ // Rest of the parameters.
+ const void* x; // Input tensor.
+ void* y; // Output tensor.
+ const void* b; // Bias tensor.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+ const float* fu; // Upsampling filter.
+ const float* fd; // Downsampling filter.
+
+ int2 pad0; // Left/top padding.
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+ int flip; // Filter kernel flip for gradient computation.
+
+ int tilesXdim; // Original number of horizontal output tiles.
+ int tilesXrep; // Number of horizontal tiles per CTA.
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
+
+ int4 xShape; // [width, height, channel, batch]
+ int4 yShape; // [width, height, channel, batch]
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+ int swLimit; // Active width of sign tensor in bytes.
+
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
+ longlong4 yStride; //
+ int64_t bStride; //
+ longlong3 fuStride; //
+ longlong3 fdStride; //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+ void* x; // Input/output, modified in-place.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+
+ int4 xShape; // [width, height, channel, batch]
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+ void* setup; // Function for filter kernel setup.
+ void* exec; // Function for main operation.
+ int2 tileOut; // Width/height of launch tile.
+ int numWarps; // Number of warps per thread block, determines launch block size.
+ int xrep; // For processing multiple horizontal tiles per thread block.
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
+template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.py b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2047b7e19320e8d03e444ca1cb03fe00d0c5e96e
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import warnings
+
+from .. import custom_ops
+from .. import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='filtered_lrelu_plugin',
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor)
+ assert 1 <= f.ndim <= 2
+ return f.shape[-1], f.shape[0] # width, height
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
+ padding = [int(x) for x in padding]
+ if len(padding) == 2:
+ px, py = padding
+ padding = [px, px, py, py]
+ px0, px1, py0, py1 = padding
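+    # e.g. padding=2 -> (2, 2, 2, 2); padding=[1, 3] -> (1, 1, 3, 3)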
+ return px0, px1, py0, py1
+
+#----------------------------------------------------------------------------
+
+def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
+ r"""Filtered leaky ReLU for a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Add channel-specific bias if provided (`b`).
+
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 5. Multiply each value by the provided gain factor (`gain`).
+
+ 6. Apply leaky ReLU activation function to each value.
+
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
+
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+ it so that the footprint of all output pixels lies within the input image.
+
+ 9. Downsample the image by keeping every Nth pixel (`down`).
+
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float16/float64 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ fu: Float32 upsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ fd: Float32 downsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The length of the vector must match the channel dimension of `x`.
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
+ flip_filter: False = convolution, True = correlation (default: False).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
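+
+# Illustrative usage sketch (helper name is arbitrary, not part of the API above):
+# the reference path needs no compiled CUDA plugin and also runs on CPU.
+def _filtered_lrelu_example():
+    f = torch.tensor([1.0, 3.0, 3.0, 1.0]) / 8  # separable 4-tap low-pass filter
+    x = torch.randn(1, 3, 16, 16)
+    y = filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=3, impl='ref')
+    return y.shape  # torch.Size([1, 3, 16, 16])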
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+ existing `upfirdn2d()` and `bias_act()` ops.
+ """
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ fu_w, fu_h = _get_filter_size(fu)
+ fd_w, fd_h = _get_filter_size(fd)
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+ misc.assert_shape(b, [x.shape[1]])
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ assert slope == float(slope) and slope >= 0
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+ # Calculate output size.
+ batch_size, channels, in_h, in_w = x.shape
+ in_dtype = x.dtype
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
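+    # e.g. in_w = 64, up = 2, px0 = px1 = 3, fu_w = fd_w = 8, down = 2
+    #      -> out_w = (128 + 6 - 7 - 7 + 1) // 2 = 60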
+
+ # Compute using existing ops.
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Check output shape & dtype.
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+ assert x.dtype == in_dtype
+ return x
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+ """
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ gain = float(gain)
+ assert slope == float(slope) and slope >= 0
+ slope = float(slope)
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+ clamp = float(clamp if clamp is not None else 'inf')
+
+ # Lookup from cache.
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
+ if key in _filtered_lrelu_cuda_cache:
+ return _filtered_lrelu_cuda_cache[key]
+
+ # Forward op.
+ class FilteredLReluCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
+ if fu is None:
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ if fd is None:
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert 1 <= fu.ndim <= 2
+ assert 1 <= fd.ndim <= 2
+
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
+ fu = fu.square()[None]
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
+ fd = fd.square()[None]
+
+ # Missing sign input tensor.
+ if si is None:
+ si = torch.empty([0])
+
+ # Missing bias tensor.
+ if b is None:
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
+
+ # Construct internal sign tensor only if gradients are needed.
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
+
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
+
+ # Call C++/Cuda plugin if datatype is supported.
+ if x.dtype in [torch.float16, torch.float32]:
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
+ else:
+ return_code = -1
+
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
+ # only the bit-packed sign tensor is retained for gradient computation.
+ if return_code < 0:
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
+
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Prepare for gradient computation.
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+ ctx.x_shape = x.shape
+ ctx.y_shape = y.shape
+ ctx.s_ofs = sx, sy
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ fu, fd, si = ctx.saved_tensors
+ _, _, xh, xw = ctx.x_shape
+ _, _, yh, yw = ctx.y_shape
+ sx, sy = ctx.s_ofs
+ dx = None # 0
+ dfu = None; assert not ctx.needs_input_grad[1]
+ dfd = None; assert not ctx.needs_input_grad[2]
+ db = None # 3
+ dsi = None; assert not ctx.needs_input_grad[4]
+ dsx = None; assert not ctx.needs_input_grad[5]
+ dsy = None; assert not ctx.needs_input_grad[6]
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+ pp = [
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+ xw * up - yw * down + px0 - (up - 1),
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+ xh * up - yh * down + py0 - (up - 1),
+ ]
+ gg = gain * (up ** 2) / (down ** 2)
+ ff = (not flip_filter)
+ sx = sx - (fu.shape[-1] - 1) + px0
+ sy = sy - (fu.shape[0] - 1) + py0
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
+
+ if ctx.needs_input_grad[3]:
+ db = dx.sum([0, 2, 3])
+
+ return dx, dfu, dfd, db, dsi, dsx, dsy
+
+ # Add to cache.
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+ return FilteredLReluCuda
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_ns.cu b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8a3eae46215c3babea2c54e3ae255b05f4d777af
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_ns.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<false, false>(cudaStream_t stream);
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_rd.cu b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3cd43ec0648d3db05e5808299fc0ee318e5ceaa6
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_rd.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<false, true>(cudaStream_t stream);
diff --git a/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_wr.cu b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_wr.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bc2fa06912eb703dd77ca64533208428bdf373ac
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/filtered_lrelu_wr.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign write mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<true, false>(cudaStream_t stream);
diff --git a/ThirdParty/eg3d/torch_utils/ops/fma.py b/ThirdParty/eg3d/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..5458116d0b6f8b133608456bbe9003aa0283ac85
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/fma.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
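+
+# Illustrative sketch (helper name is arbitrary): fma(a, b, c) equals a * b + c with
+# broadcasting of c, and gradients flow to all three inputs.
+def _fma_example():
+    a = torch.randn(3, 4, requires_grad=True)
+    b = torch.randn(3, 4, requires_grad=True)
+    c = torch.randn(4, requires_grad=True)  # broadcast across rows
+    y = fma(a, b, c)
+    y.sum().backward()                      # populates a.grad, b.grad, c.grad
+    return torch.allclose(y, a * b + c)     # True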
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/grid_sample_gradfix.py b/ThirdParty/eg3d/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d94724136ba162d8416803b1ad00d6da0db99f
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
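+
+# Illustrative sketch (helper name is arbitrary). With `enabled` left at False this
+# dispatches to torch.nn.functional.grid_sample; setting it to True routes through
+# the custom autograd op defined below.
+def _grid_sample_example():
+    image = torch.randn(1, 3, 8, 8, requires_grad=True)
+    grid = torch.rand(1, 4, 4, 2) * 2 - 1  # normalized sample coordinates in [-1, 1]
+    out = grid_sample(image, grid)         # shape [1, 3, 4, 4]
+    out.sum().backward()                   # gradient flows back to `image`
+    return out.shape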
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ return enabled
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cpp b/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c1769c3cbe4dd04f76f9ccef726680720e6f39c8
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,111 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
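+ // Example: a 64x64 input with upx=upy=2, 3 pixels of padding on every side and an 8x8 filter at downx=downy=2 gives outW = outH = (64*2 + 3 + 3 - 8 + 2) / 2 = 64.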
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cu b/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7d182d7b86a9058d0c007b13716d6e7f08207f42
--- /dev/null
+++ b/ThirdParty/eg3d/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,388 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
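+// Floor division that rounds toward negative infinity, e.g. floor_div(-3, 2) == -2
+// whereas plain integer division -3 / 2 truncates to -1.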
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType<T>::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
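+//
+// Each block loads the filter (stored flipped) and one tile of input pixels
+// into shared memory, then its threads stride over the tile's output pixels,
+// accumulating each one entirely from shared memory. The up/down factors,
+// filter footprint and tile shape are template parameters, so the inner
+// accumulation loops are fully unrolled at compile time.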
+
+template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType<T>::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
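+//
+// Picks the tightest compile-time specialization of upfirdn2d_kernel_small
+// whose filter footprint covers (filterSize.x, filterSize.y) for the given
+// up/down factors; later matches in each if-chain overwrite earlier ones.
+// Channels-last tensors (inStride.z == 1) get their own tile shapes. When no
+// specialization applies, spec keeps its initial value, the generic
+// upfirdn2d_kernel_large path (tileOutW == -1).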
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+ // No up/downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,24, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 7,7, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 6,6, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 5,5, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 4,4, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 3,3, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 8,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,24, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,16, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,8, 32,32,1>, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,24, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,16, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 7,7, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 6,6, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 5,5, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 4,4, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 3,3, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 24,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 16,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 8,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,24, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,16, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,1, 1,8, 1,128,16>, 1,128,16, 1};
+ }
+
+ // 2x upsampling.
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 24,24, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2, 64,16,1>, 64,16,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 24,24, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 16,16, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2, 16,16,8>, 16,16,8, 1};
+ }
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1, 128,8,1>, 128,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1, 128,1,16>, 128,1,16, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,24, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small