diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ce21ee3c9528cc301542075929f2e4a3f5f4f223
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+ckpt
+result
+**/__pycache__/
+**/.DS_Store
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f1aa820859230b0ffcae3ca02103f982c06c62
--- /dev/null
+++ b/app.py
@@ -0,0 +1,139 @@
+import gradio as gr
+import numpy as np
+import glob
+import torch
+import random
+from tempfile import NamedTemporaryFile
+from infer_api import InferAPI
+from PIL import Image
+
+config_canocalize = {
+    'config_path': './configs/canonicalization-infer.yaml',
+}
+config_multiview = {}
+config_slrm = {
+    'config_path': './configs/mesh-slrm-infer.yaml'
+}
+config_refine = {}
+
+EXAMPLE_IMAGES = glob.glob("./input_cases/*")
+EXAMPLE_APOSE_IMAGES = glob.glob("./input_cases_apose/*")
+
+infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine)
+
+REMINDER = """
+### Reminder:
+1. **Reference Image**:
+   - You can upload any reference image (with or without background).
+   - If the image has an alpha channel (transparency), background segmentation will be automatically performed.
+   - Alternatively, you can pre-segment the background using other tools and upload the result directly.
+   - A-pose images are also supported.
+
+2. Real person images generally work well, but note that normals may appear smoother than expected. You can try using other monocular normal estimation models.
+
+3. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions.
+"""
+
+# Example placeholder functions - replace with the actual model
+def arbitrary_to_apose(image, seed):
+    # convert image to PIL.Image
+    image = Image.fromarray(image)
+    return infer_api.genStage1(image, seed)
+
+def apose_to_multiview(apose_img, seed):
+    # convert image to PIL.Image
+    apose_img = Image.fromarray(apose_img)
+    return infer_api.genStage2(apose_img, seed, num_levels=1)[0]["images"]
+
+def multiview_to_mesh(images):
+    mesh_files = infer_api.genStage3(images)
+    return mesh_files
+
+def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
+    apose_img = Image.fromarray(apose_img)
+    infer_api.genStage2(apose_img, seed, num_levels=2)
+    print(infer_api.multiview_infer.results.keys())
+    refined = infer_api.genStage4([mesh1, mesh2, mesh3], infer_api.multiview_infer.results)
+    return refined
+
+with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
+    gr.Markdown(REMINDER)
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("## 1. Reference Image to A-pose Image")
+            input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
+            gr.Examples(
+                examples=EXAMPLE_IMAGES,
+                inputs=input_image,
+                label="Click to use sample images",
+            )
+            seed_input = gr.Number(
+                label="Seed",
+                value=42,
+                precision=0,
+                interactive=True
+            )
+            pose_btn = gr.Button("Convert")
+        with gr.Column():
+            gr.Markdown("## 2. Multi-view Generation")
+            a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
+            gr.Examples(
+                examples=EXAMPLE_APOSE_IMAGES,
+                inputs=a_pose_image,
+                label="Click to use sample A-pose images",
+            )
+            seed_input2 = gr.Number(
+                label="Seed",
+                value=42,
+                precision=0,
+                interactive=True
+            )
+            view_btn = gr.Button("Generate Multi-view Images")
+
+        with gr.Column():
+            gr.Markdown("## 3. Semantic-aware Reconstruction")
+            multiview_gallery = gr.Gallery(
+                label="Multi-view results",
+                columns=2,
+                interactive=False,
+                height="None"
+            )
+            mesh_btn = gr.Button("Reconstruct")
+
+    with gr.Row():
+        mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
+        full_mesh = gr.Model3D(label="Whole Mesh", height=384)
+        refine_btn = gr.Button("Refine")
+
+    gr.Markdown("## 4. Mesh refinement")
+    with gr.Row():
+        refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)]
+        refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384)
+
+    # Interaction logic
+    pose_btn.click(
+        arbitrary_to_apose,
+        inputs=[input_image, seed_input],
+        outputs=a_pose_image
+    )
+
+    view_btn.click(
+        apose_to_multiview,
+        inputs=[a_pose_image, seed_input2],
+        outputs=multiview_gallery
+    )
+
+    mesh_btn.click(
+        multiview_to_mesh,
+        inputs=multiview_gallery,
+        outputs=[*mesh_cols, full_mesh]
+    )
+
+    refine_btn.click(
+        refine_mesh,
+        inputs=[a_pose_image, *mesh_cols, seed_input2],
+        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh]
+    )
+
+if __name__ == "__main__":
+    demo.launch()
\ No newline at end of file
diff --git a/blender/blender_lrm_script.py b/blender/blender_lrm_script.py
new file mode 100755
index 0000000000000000000000000000000000000000..68d11591e0b118c902bfbdcb4823b07a3fc9716d
--- /dev/null
+++ b/blender/blender_lrm_script.py
@@ -0,0 +1,1387 @@
+"""Blender script to render images of 3D models."""
+
+import argparse
+import json
+import math
+import os
+import random
+import sys
+from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
+
+import bpy
+import numpy as np
+from mathutils import Matrix, Vector
+import pdb
+MAX_DEPTH = 5.0
+import shutil
+IMPORT_FUNCTIONS: Dict[str, Callable] = {
+    "obj": bpy.ops.import_scene.obj,
+    "glb": bpy.ops.import_scene.gltf,
+    "gltf": bpy.ops.import_scene.gltf,
+    "usd": bpy.ops.import_scene.usd,
+    "fbx": bpy.ops.import_scene.fbx,
+    "stl": bpy.ops.import_mesh.stl,
+    "usda": bpy.ops.import_scene.usda,
+    "dae": bpy.ops.wm.collada_import,
+    "ply": bpy.ops.import_mesh.ply,
+    "abc": bpy.ops.wm.alembic_import,
+    "blend": bpy.ops.wm.append,
+    "vrm": bpy.ops.import_scene.vrm,
+}
+
+configs = {
+    "custom2": {"camera_pose": "z-circular-elevated", 'elevation_range': [0,0], "rotate": 0.0},
+    "custom_top": {"camera_pose": "z-circular-elevated", 'elevation_range': [90,90], "rotate": 0.0, "render_num": 1},
+    "custom_bottom": {"camera_pose": "z-circular-elevated", 'elevation_range': [-90,-90], "rotate": 0.0, "render_num": 1},
+    "custom_face": {"camera_pose": "z-circular-elevated", 'elevation_range': [0,0], "rotate": 0.0, "render_num": 8},
+    "random": {"camera_pose": "random", 'elevation_range': [-90,90], "rotate": 0.0, "render_num": 20},
+}
+
+
+def reset_cameras() -> None:
+    """Resets the cameras in the scene to a single default camera."""
+    # Delete all existing cameras
+    bpy.ops.object.select_all(action="DESELECT")
+    bpy.ops.object.select_by_type(type="CAMERA")
+    bpy.ops.object.delete()
+
+    # Create a new camera with default properties
+    bpy.ops.object.camera_add()
+
+    # Rename the new camera to 'NewDefaultCamera'
+    new_camera = bpy.context.active_object
+    new_camera.name = "Camera"
+
+    # Set the new camera as the active camera for the scene
+    scene.camera = new_camera
+
+
+def _sample_spherical(
+    radius_min: float = 1.5,
+    radius_max: float = 2.0,
+    maxz: float = 1.6,
+    minz: float = -0.75,
+) -> np.ndarray:
+    """Sample a random point in a
spherical shell. + + Args: + radius_min (float): Minimum radius of the spherical shell. + radius_max (float): Maximum radius of the spherical shell. + maxz (float): Maximum z value of the spherical shell. + minz (float): Minimum z value of the spherical shell. + + Returns: + np.ndarray: A random (x, y, z) point in the spherical shell. + """ + correct = False + vec = np.array([0, 0, 0]) + while not correct: + vec = np.random.uniform(-1, 1, 3) + # vec[2] = np.abs(vec[2]) + radius = np.random.uniform(radius_min, radius_max, 1) + vec = vec / np.linalg.norm(vec, axis=0) * radius[0] + if maxz > vec[2] > minz: + correct = True + return vec + + +def randomize_camera( + radius_min: float = 1.5, + radius_max: float = 2.2, + maxz: float = 2.2, + minz: float = -2.2, + only_northern_hemisphere: bool = False, +) -> bpy.types.Object: + """Randomizes the camera location and rotation inside of a spherical shell. + + Args: + radius_min (float, optional): Minimum radius of the spherical shell. Defaults to + 1.5. + radius_max (float, optional): Maximum radius of the spherical shell. Defaults to + 2.0. + maxz (float, optional): Maximum z value of the spherical shell. Defaults to 1.6. + minz (float, optional): Minimum z value of the spherical shell. Defaults to + -0.75. + only_northern_hemisphere (bool, optional): Whether to only sample points in the + northern hemisphere. Defaults to False. + + Returns: + bpy.types.Object: The camera object. + """ + + x, y, z = _sample_spherical( + radius_min=radius_min, radius_max=radius_max, maxz=maxz, minz=minz + ) + camera = bpy.data.objects["Camera"] + + # only positive z + if only_northern_hemisphere: + z = abs(z) + + camera.location = Vector(np.array([x, y, z])) + + direction = -camera.location + rot_quat = direction.to_track_quat("-Z", "Y") + camera.rotation_euler = rot_quat.to_euler() + + return camera + + +cached_cameras = [] + +def randomize_camera_with_cache( + radius_min: float = 1.5, + radius_max: float = 2.2, + maxz: float = 2.2, + minz: float = -2.2, + only_northern_hemisphere: bool = False, + idx: int = 0, +) -> bpy.types.Object: + + assert len(cached_cameras) >= idx + + if len(cached_cameras) == idx: + x, y, z = _sample_spherical( + radius_min=radius_min, radius_max=radius_max, maxz=maxz, minz=minz + ) + cached_cameras.append((x, y, z)) + else: + x, y, z = cached_cameras[idx] + + camera = bpy.data.objects["Camera"] + + # only positive z + if only_northern_hemisphere: + z = abs(z) + + camera.location = Vector(np.array([x, y, z])) + + direction = -camera.location + rot_quat = direction.to_track_quat("-Z", "Y") + camera.rotation_euler = rot_quat.to_euler() + + return camera + + +def set_camera(direction, camera_dist=2.0, camera_offset=0.0): + camera = bpy.data.objects["Camera"] + camera_pos = -camera_dist * direction + if type(camera_offset) == float: + camera_offset = Vector(np.array([0., 0., 0.])) + camera_pos += camera_offset + camera.location = camera_pos + + # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically + rot_quat = direction.to_track_quat("-Z", "Y") + camera.rotation_euler = rot_quat.to_euler() + return camera + + +def _set_camera_at_size(i: int, scale: float = 1.5) -> bpy.types.Object: + """Debugging function to set the camera on the 6 faces of a cube. + + Args: + i (int): Index of the face of the cube. + scale (float, optional): Scale of the cube. Defaults to 1.5. + + Returns: + bpy.types.Object: The camera object. 
+ """ + if i == 0: + x, y, z = scale, 0, 0 + elif i == 1: + x, y, z = -scale, 0, 0 + elif i == 2: + x, y, z = 0, scale, 0 + elif i == 3: + x, y, z = 0, -scale, 0 + elif i == 4: + x, y, z = 0, 0, scale + elif i == 5: + x, y, z = 0, 0, -scale + else: + raise ValueError(f"Invalid index: i={i}, must be int in range [0, 5].") + camera = bpy.data.objects["Camera"] + camera.location = Vector(np.array([x, y, z])) + direction = -camera.location + rot_quat = direction.to_track_quat("-Z", "Y") + camera.rotation_euler = rot_quat.to_euler() + return camera + + +def _create_light( + name: str, + light_type: Literal["POINT", "SUN", "SPOT", "AREA"], + location: Tuple[float, float, float], + rotation: Tuple[float, float, float], + energy: float, + use_shadow: bool = False, + specular_factor: float = 1.0, +): + """Creates a light object. + + Args: + name (str): Name of the light object. + light_type (Literal["POINT", "SUN", "SPOT", "AREA"]): Type of the light. + location (Tuple[float, float, float]): Location of the light. + rotation (Tuple[float, float, float]): Rotation of the light. + energy (float): Energy of the light. + use_shadow (bool, optional): Whether to use shadows. Defaults to False. + specular_factor (float, optional): Specular factor of the light. Defaults to 1.0. + + Returns: + bpy.types.Object: The light object. + """ + + light_data = bpy.data.lights.new(name=name, type=light_type) + light_object = bpy.data.objects.new(name, light_data) + bpy.context.collection.objects.link(light_object) + light_object.location = location + light_object.rotation_euler = rotation + light_data.use_shadow = use_shadow + light_data.specular_factor = specular_factor + light_data.energy = energy + return light_object + + +def reset_scene() -> None: + """Resets the scene to a clean state. + + Returns: + None + """ + # delete everything that isn't part of a camera or a light + for obj in bpy.data.objects: + if obj.type not in {"CAMERA", "LIGHT"}: + bpy.data.objects.remove(obj, do_unlink=True) + + # delete all the materials + for material in bpy.data.materials: + bpy.data.materials.remove(material, do_unlink=True) + + # delete all the textures + for texture in bpy.data.textures: + bpy.data.textures.remove(texture, do_unlink=True) + + # delete all the images + for image in bpy.data.images: + bpy.data.images.remove(image, do_unlink=True) + + # delete all the collider collections + for collider in bpy.data.collections: + if collider.name != "Collection": + bpy.data.collections.remove(collider, do_unlink=True) + + +def load_object(object_path: str) -> None: + """Loads a model with a supported file extension into the scene. + + Args: + object_path (str): Path to the model file. + + Raises: + ValueError: If the file extension is not supported. 
+ + Returns: + None + """ + file_extension = object_path.split(".")[-1].lower() + if file_extension is None: + raise ValueError(f"Unsupported file type: {object_path}") + + if file_extension == "usdz": + # install usdz io package + dirname = os.path.dirname(os.path.realpath(__file__)) + usdz_package = os.path.join(dirname, "io_scene_usdz.zip") + bpy.ops.preferences.addon_install(filepath=usdz_package) + # enable it + addon_name = "io_scene_usdz" + bpy.ops.preferences.addon_enable(module=addon_name) + # import the usdz + from io_scene_usdz.import_usdz import import_usdz + + import_usdz(context, filepath=object_path, materials=True, animations=True) + return None + + # load from existing import functions + import_function = IMPORT_FUNCTIONS[file_extension] + + if file_extension == "blend": + import_function(directory=object_path, link=False) + elif file_extension in {"glb", "gltf"}: + import_function(filepath=object_path, merge_vertices=True) + else: + import_function(filepath=object_path) + + +def scene_bbox( + single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False +) -> Tuple[Vector, Vector]: + """Returns the bounding box of the scene. + + Taken from Shap-E rendering script + (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82) + + Args: + single_obj (Optional[bpy.types.Object], optional): If not None, only computes + the bounding box for the given object. Defaults to None. + ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults + to False. + + Raises: + RuntimeError: If there are no objects in the scene. + + Returns: + Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box. + """ + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + for obj in get_scene_meshes() if single_obj is None else [single_obj]: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + if not ignore_matrix: + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + + return Vector(bbox_min), Vector(bbox_max) + + +def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]: + """Returns all root objects in the scene. + + Yields: + Generator[bpy.types.Object, None, None]: Generator of all root objects in the + scene. + """ + for obj in bpy.context.scene.objects.values(): + if not obj.parent: + yield obj + + +def get_scene_meshes() -> Generator[bpy.types.Object, None, None]: + """Returns all meshes in the scene. + + Yields: + Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene. + """ + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, (bpy.types.Mesh)): + yield obj + + +def get_3x4_RT_matrix_from_blender(cam: bpy.types.Object) -> Matrix: + """Returns the 3x4 RT matrix from the given camera. + + Taken from Zero123, which in turn was taken from + https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py + + Args: + cam (bpy.types.Object): The camera object. + + Returns: + Matrix: The 3x4 RT matrix from the given camera. 
+ """ + # Use matrix_world instead to account for all constraints + location, rotation = cam.matrix_world.decompose()[0:2] + R_world2bcam = rotation.to_matrix().transposed() + + # Use location from matrix_world to account for constraints: + T_world2bcam = -1 * R_world2bcam @ location + + # put into 3x4 matrix + RT = Matrix( + ( + R_world2bcam[0][:] + (T_world2bcam[0],), + R_world2bcam[1][:] + (T_world2bcam[1],), + R_world2bcam[2][:] + (T_world2bcam[2],), + ) + ) + return RT + + +def delete_invisible_objects() -> None: + """Deletes all invisible objects in the scene. + + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + for obj in scene.objects: + if obj.hide_viewport or obj.hide_render: + obj.hide_viewport = False + obj.hide_render = False + obj.hide_select = False + obj.select_set(True) + bpy.ops.object.delete() + + # Delete invisible collections + invisible_collections = [col for col in bpy.data.collections if col.hide_viewport] + for col in invisible_collections: + bpy.data.collections.remove(col) + + +def normalize_scene() -> None: + """Normalizes the scene by scaling and translating it to fit in a unit cube centered + at the origin. + + Mostly taken from the Point-E / Shap-E rendering script + (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112), + but fix for multiple root objects: (see bug report here: + https://github.com/openai/shap-e/pull/60). + + Returns: + None + """ + if len(list(get_scene_root_objects())) > 1: + # create an empty object to be used as a parent for all root objects + parent_empty = bpy.data.objects.new("ParentEmpty", None) + bpy.context.scene.collection.objects.link(parent_empty) + + # parent all root objects to the empty object + for obj in get_scene_root_objects(): + if obj != parent_empty: + obj.parent = parent_empty + + bbox_min, bbox_max = scene_bbox() + scale = 1 / max(bbox_max - bbox_min) + for obj in get_scene_root_objects(): + obj.scale = obj.scale * scale + + # Apply scale to matrix_world. + bpy.context.view_layer.update() + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + for obj in get_scene_root_objects(): + obj.matrix_world.translation += offset + bpy.ops.object.select_all(action="DESELECT") + + # unparent the camera + bpy.data.objects["Camera"].parent = None + + +def delete_missing_textures() -> Dict[str, Any]: + """Deletes all missing textures in the scene. + + Returns: + Dict[str, Any]: Dictionary with keys "count", "files", and "file_path_to_color". + "count" is the number of missing textures, "files" is a list of the missing + texture file paths, and "file_path_to_color" is a dictionary mapping the + missing texture file paths to a random color. 
+ """ + missing_file_count = 0 + out_files = [] + file_path_to_color = {} + + # Check all materials in the scene + for material in bpy.data.materials: + if material.use_nodes: + for node in material.node_tree.nodes: + if node.type == "TEX_IMAGE": + image = node.image + if image is not None: + file_path = bpy.path.abspath(image.filepath) + if file_path == "": + # means it's embedded + continue + + if not os.path.exists(file_path): + # Find the connected Principled BSDF node + connected_node = node.outputs[0].links[0].to_node + + if connected_node.type == "BSDF_PRINCIPLED": + if file_path not in file_path_to_color: + # Set a random color for the unique missing file path + random_color = [random.random() for _ in range(3)] + file_path_to_color[file_path] = random_color + [1] + + connected_node.inputs[ + "Base Color" + ].default_value = file_path_to_color[file_path] + + # Delete the TEX_IMAGE node + material.node_tree.nodes.remove(node) + missing_file_count += 1 + out_files.append(image.filepath) + return { + "count": missing_file_count, + "files": out_files, + "file_path_to_color": file_path_to_color, + } + + +def _get_random_color() -> Tuple[float, float, float, float]: + """Generates a random RGB-A color. + + The alpha value is always 1. + + Returns: + Tuple[float, float, float, float]: A random RGB-A color. Each value is in the + range [0, 1]. + """ + return (random.random(), random.random(), random.random(), 1) + + +def _apply_color_to_object( + obj: bpy.types.Object, color: Tuple[float, float, float, float] +) -> None: + """Applies the given color to the object. + + Args: + obj (bpy.types.Object): The object to apply the color to. + color (Tuple[float, float, float, float]): The color to apply to the object. + + Returns: + None + """ + mat = bpy.data.materials.new(name=f"RandomMaterial_{obj.name}") + mat.use_nodes = True + nodes = mat.node_tree.nodes + principled_bsdf = nodes.get("Principled BSDF") + if principled_bsdf: + principled_bsdf.inputs["Base Color"].default_value = color + obj.data.materials.append(mat) + + +class MetadataExtractor: + """Class to extract metadata from a Blender scene.""" + + def __init__( + self, object_path: str, scene: bpy.types.Scene, bdata: bpy.types.BlendData + ) -> None: + """Initializes the MetadataExtractor. + + Args: + object_path (str): Path to the object file. + scene (bpy.types.Scene): The current scene object from `bpy.context.scene`. + bdata (bpy.types.BlendData): The current blender data from `bpy.data`. 
+ + Returns: + None + """ + self.object_path = object_path + self.scene = scene + self.bdata = bdata + + def get_poly_count(self) -> int: + """Returns the total number of polygons in the scene.""" + total_poly_count = 0 + for obj in self.scene.objects: + if obj.type == "MESH": + total_poly_count += len(obj.data.polygons) + return total_poly_count + + def get_vertex_count(self) -> int: + """Returns the total number of vertices in the scene.""" + total_vertex_count = 0 + for obj in self.scene.objects: + if obj.type == "MESH": + total_vertex_count += len(obj.data.vertices) + return total_vertex_count + + def get_edge_count(self) -> int: + """Returns the total number of edges in the scene.""" + total_edge_count = 0 + for obj in self.scene.objects: + if obj.type == "MESH": + total_edge_count += len(obj.data.edges) + return total_edge_count + + def get_lamp_count(self) -> int: + """Returns the number of lamps in the scene.""" + return sum(1 for obj in self.scene.objects if obj.type == "LIGHT") + + def get_mesh_count(self) -> int: + """Returns the number of meshes in the scene.""" + return sum(1 for obj in self.scene.objects if obj.type == "MESH") + + def get_material_count(self) -> int: + """Returns the number of materials in the scene.""" + return len(self.bdata.materials) + + def get_object_count(self) -> int: + """Returns the number of objects in the scene.""" + return len(self.bdata.objects) + + def get_animation_count(self) -> int: + """Returns the number of animations in the scene.""" + return len(self.bdata.actions) + + def get_linked_files(self) -> List[str]: + """Returns the filepaths of all linked files.""" + image_filepaths = self._get_image_filepaths() + material_filepaths = self._get_material_filepaths() + linked_libraries_filepaths = self._get_linked_libraries_filepaths() + + all_filepaths = ( + image_filepaths | material_filepaths | linked_libraries_filepaths + ) + if "" in all_filepaths: + all_filepaths.remove("") + return list(all_filepaths) + + def _get_image_filepaths(self) -> Set[str]: + """Returns the filepaths of all images used in the scene.""" + filepaths = set() + for image in self.bdata.images: + if image.source == "FILE": + filepaths.add(bpy.path.abspath(image.filepath)) + return filepaths + + def _get_material_filepaths(self) -> Set[str]: + """Returns the filepaths of all images used in materials.""" + filepaths = set() + for material in self.bdata.materials: + if material.use_nodes: + for node in material.node_tree.nodes: + if node.type == "TEX_IMAGE": + image = node.image + if image is not None: + filepaths.add(bpy.path.abspath(image.filepath)) + return filepaths + + def _get_linked_libraries_filepaths(self) -> Set[str]: + """Returns the filepaths of all linked libraries.""" + filepaths = set() + for library in self.bdata.libraries: + filepaths.add(bpy.path.abspath(library.filepath)) + return filepaths + + def get_scene_size(self) -> Dict[str, list]: + """Returns the size of the scene bounds in meters.""" + bbox_min, bbox_max = scene_bbox() + return {"bbox_max": list(bbox_max), "bbox_min": list(bbox_min)} + + def get_shape_key_count(self) -> int: + """Returns the number of shape keys in the scene.""" + total_shape_key_count = 0 + for obj in self.scene.objects: + if obj.type == "MESH": + shape_keys = obj.data.shape_keys + if shape_keys is not None: + total_shape_key_count += ( + len(shape_keys.key_blocks) - 1 + ) # Subtract 1 to exclude the Basis shape key + return total_shape_key_count + + def get_armature_count(self) -> int: + """Returns the number of armatures in 
the scene.""" + total_armature_count = 0 + for obj in self.scene.objects: + if obj.type == "ARMATURE": + total_armature_count += 1 + return total_armature_count + + def read_file_size(self) -> int: + """Returns the size of the file in bytes.""" + return os.path.getsize(self.object_path) + + def get_metadata(self) -> Dict[str, Any]: + """Returns the metadata of the scene. + + Returns: + Dict[str, Any]: Dictionary of the metadata with keys for "file_size", + "poly_count", "vert_count", "edge_count", "material_count", "object_count", + "lamp_count", "mesh_count", "animation_count", "linked_files", "scene_size", + "shape_key_count", and "armature_count". + """ + return { + "file_size": self.read_file_size(), + "poly_count": self.get_poly_count(), + "vert_count": self.get_vertex_count(), + "edge_count": self.get_edge_count(), + "material_count": self.get_material_count(), + "object_count": self.get_object_count(), + "lamp_count": self.get_lamp_count(), + "mesh_count": self.get_mesh_count(), + "animation_count": self.get_animation_count(), + "linked_files": self.get_linked_files(), + "scene_size": self.get_scene_size(), + "shape_key_count": self.get_shape_key_count(), + "armature_count": self.get_armature_count(), + } + +def pan_camera(time, axis="Z", camera_dist=2.0, elevation=-0.1, camera_offset=0.0): + angle = time * math.pi * 2 - math.pi / 2 # start from -90 degree + direction = [-math.cos(angle), -math.sin(angle), -elevation] + assert axis in ["X", "Y", "Z"] + if axis == "X": + direction = [direction[2], *direction[:2]] + elif axis == "Y": + direction = [direction[0], -elevation, direction[1]] + direction = Vector(direction).normalized() + camera = set_camera(direction, camera_dist=camera_dist, camera_offset=camera_offset) + return camera + + +def pan_camera_along(time, pose="alone-x-rotate", camera_dist=2.0, rotate=0.0): + angle = time * math.pi * 2 + # direction_plane = [-math.cos(angle), -math.sin(angle), 0] + x_new = math.cos(angle) + y_new = math.cos(rotate) * math.sin(angle) + z_new = math.sin(rotate) * math.sin(angle) + direction = [-x_new, -y_new, -z_new] + assert pose in ["alone-x-rotate"] + direction = Vector(direction).normalized() + camera = set_camera(direction, camera_dist=camera_dist) + return camera + +def pan_camera_by_angle(angle, axis="Z", camera_dist=2.0, elevation=-0.1 ): + direction = [-math.cos(angle), -math.sin(angle), -elevation] + assert axis in ["X", "Y", "Z"] + if axis == "X": + direction = [direction[2], *direction[:2]] + elif axis == "Y": + direction = [direction[0], -elevation, direction[1]] + direction = Vector(direction).normalized() + camera = set_camera(direction, camera_dist=camera_dist) + return camera + +def z_circular_custom_track(time, + camera_dist, + azimuth_shift = [-9, 9], + init_elevation = 0.0, + elevation_shift = [-5, 5]): + + adjusted_azimuth = (-math.degrees(math.pi / 2) + + time * 360 + + np.random.uniform(low=azimuth_shift[0], high=azimuth_shift[1])) + + # Add random noise to the elevation + adjusted_elevation = init_elevation + np.random.uniform(low=elevation_shift[0], high=elevation_shift[1]) + return math.radians(adjusted_azimuth), math.radians(adjusted_elevation), camera_dist + + +def place_camera(time, camera_pose_mode="random", camera_dist=2.0, rotate=0.0, elevation=0.0, camera_offset=0.0, idx=0): + if camera_pose_mode == "z-circular-elevated": + cam = pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=elevation, camera_offset=camera_offset) + elif camera_pose_mode == 'alone-x-rotate': + cam = pan_camera_along(time, 
pose=camera_pose_mode, camera_dist=camera_dist, rotate=rotate) + elif camera_pose_mode == 'z-circular-elevated-noise': + angle, elevation, camera_dist = z_circular_custom_track(time, camera_dist=camera_dist, init_elevation=elevation) + cam = pan_camera_by_angle(angle, axis="Z", camera_dist=camera_dist, elevation=elevation) + elif camera_pose_mode == 'random': + cam = randomize_camera_with_cache(radius_min=camera_dist, radius_max=camera_dist, maxz=114514., minz=-114514., idx=idx) + else: + raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}") + return cam + + +def setup_nodes(output_path, capturing_material_alpha: bool = False): + tree = bpy.context.scene.node_tree + links = tree.links + + for node in tree.nodes: + tree.nodes.remove(node) + + # Helpers to perform math on links and constants. + def node_op(op: str, *args, clamp=False): + node = tree.nodes.new(type="CompositorNodeMath") + node.operation = op + if clamp: + node.use_clamp = True + for i, arg in enumerate(args): + if isinstance(arg, (int, float)): + node.inputs[i].default_value = arg + else: + links.new(arg, node.inputs[i]) + return node.outputs[0] + + def node_clamp(x, maximum=1.0): + return node_op("MINIMUM", x, maximum) + + def node_mul(x, y, **kwargs): + return node_op("MULTIPLY", x, y, **kwargs) + + input_node = tree.nodes.new(type="CompositorNodeRLayers") + input_node.scene = bpy.context.scene + + input_sockets = {} + for output in input_node.outputs: + input_sockets[output.name] = output + + if capturing_material_alpha: + color_socket = input_sockets["Image"] + else: + raw_color_socket = input_sockets["Image"] + + # We apply sRGB here so that our fixed-point depth map and material + # alpha values are not sRGB, and so that we perform ambient+diffuse + # lighting in linear RGB space. + color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace") + color_node.from_color_space = "Linear" + color_node.to_color_space = "sRGB" + tree.links.new(raw_color_socket, color_node.inputs[0]) + color_socket = color_node.outputs[0] + split_node = tree.nodes.new(type="CompositorNodeSepRGBA") + tree.links.new(color_socket, split_node.inputs[0]) + # Create separate file output nodes for every channel we care about. + # The process calling this script must decide how to recombine these + # channels, possibly into a single image. + for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]: + output_node = tree.nodes.new(type="CompositorNodeOutputFile") + output_node.base_path = f"{output_path}_{channel}" + links.new(split_node.outputs[i], output_node.inputs[0]) + if capturing_material_alpha: + # No need to re-write depth here. 
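+        # Only the MatAlpha channel output created above is needed in this mode,
+        # so the depth and normal file outputs configured below are skipped.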
+ return + + depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH)) + output_node = tree.nodes.new(type="CompositorNodeOutputFile") + output_node.format.file_format = 'OPEN_EXR' + output_node.base_path = f"{output_path}_depth" + links.new(depth_out, output_node.inputs[0]) + + # Add normal map output + normal_out = input_sockets["Normal"] + + # Scale normal by 0.5 + scale_normal = tree.nodes.new(type="CompositorNodeMixRGB") + scale_normal.blend_type = 'MULTIPLY' + scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1) + links.new(normal_out, scale_normal.inputs[1]) + + # Bias normal by 0.5 + bias_normal = tree.nodes.new(type="CompositorNodeMixRGB") + bias_normal.blend_type = 'ADD' + bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0) + links.new(scale_normal.outputs[0], bias_normal.inputs[1]) + + # Output the transformed normal map + normal_file_output = tree.nodes.new(type="CompositorNodeOutputFile") + normal_file_output.base_path = f"{output_path}_normal" + normal_file_output.format.file_format = 'OPEN_EXR' + links.new(bias_normal.outputs[0], normal_file_output.inputs[0]) + + +def setup_nodes_semantic(output_path, capturing_material_alpha: bool = False): + tree = bpy.context.scene.node_tree + links = tree.links + + for node in tree.nodes: + tree.nodes.remove(node) + + # Helpers to perform math on links and constants. + def node_op(op: str, *args, clamp=False): + node = tree.nodes.new(type="CompositorNodeMath") + node.operation = op + if clamp: + node.use_clamp = True + for i, arg in enumerate(args): + if isinstance(arg, (int, float)): + node.inputs[i].default_value = arg + else: + links.new(arg, node.inputs[i]) + return node.outputs[0] + + def node_clamp(x, maximum=1.0): + return node_op("MINIMUM", x, maximum) + + def node_mul(x, y, **kwargs): + return node_op("MULTIPLY", x, y, **kwargs) + + input_node = tree.nodes.new(type="CompositorNodeRLayers") + input_node.scene = bpy.context.scene + + input_sockets = {} + for output in input_node.outputs: + input_sockets[output.name] = output + + if capturing_material_alpha: + color_socket = input_sockets["Image"] + else: + raw_color_socket = input_sockets["Image"] + # We apply sRGB here so that our fixed-point depth map and material + # alpha values are not sRGB, and so that we perform ambient+diffuse + # lighting in linear RGB space. + color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace") + color_node.from_color_space = "Linear" + color_node.to_color_space = "sRGB" + tree.links.new(raw_color_socket, color_node.inputs[0]) + color_socket = color_node.outputs[0] + + +def render_object( + object_file: str, + num_renders: int, + only_northern_hemisphere: bool, + output_dir: str, +) -> None: + """Saves rendered images with its camera matrix and metadata of the object. + + Args: + object_file (str): Path to the object file. + num_renders (int): Number of renders to save of the object. + only_northern_hemisphere (bool): Whether to only render sides of the object that + are in the northern hemisphere. This is useful for rendering objects that + are photogrammetrically scanned, as the bottom of the object often has + holes. + output_dir (str): Path to the directory where the rendered images and metadata + will be saved. 
+ + Returns: + None + """ + os.makedirs(output_dir, exist_ok=True) + + # load the object + if object_file.endswith(".blend"): + bpy.ops.object.mode_set(mode="OBJECT") + reset_cameras() + delete_invisible_objects() + else: + reset_scene() + load_object(object_file) + + # Set up cameras + cam = scene.objects["Camera"] + cam.data.lens = 35 + cam.data.sensor_width = 32 + + # Set up camera constraints + cam_constraint = cam.constraints.new(type="TRACK_TO") + cam_constraint.track_axis = "TRACK_NEGATIVE_Z" + cam_constraint.up_axis = "UP_Y" + + # Extract the metadata. This must be done before normalizing the scene to get + # accurate bounding box information. + metadata_extractor = MetadataExtractor( + object_path=object_file, scene=scene, bdata=bpy.data + ) + metadata = metadata_extractor.get_metadata() + + # delete all objects that are not meshes + if object_file.lower().endswith(".usdz") or object_file.lower().endswith(".vrm"): + # don't delete missing textures on usdz files, lots of them are embedded + missing_textures = None + else: + missing_textures = delete_missing_textures() + metadata["missing_textures"] = missing_textures + metadata["random_color"] = None + + # save metadata + metadata_path = os.path.join(output_dir, "metadata.json") + os.makedirs(os.path.dirname(metadata_path), exist_ok=True) + with open(metadata_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, sort_keys=True, indent=2) + + # normalize the scene + normalize_scene() + + # cancel edge rim lighting in vrm files + if object_file.endswith(".vrm"): + for i in bpy.data.materials: + i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.rim_lighting_mix_factor = 0.0 + i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.matcap_texture.index.source = None + i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.outline_width_factor = 0.0 + + # rotate two arms to A-pose + if object_file.endswith(".vrm"): + armature = [ i for i in bpy.data.objects if 'Armature' in i.name ][0] + bpy.context.view_layer.objects.active = armature + bpy.ops.object.mode_set(mode='POSE') + pbone1 = armature.pose.bones['J_Bip_L_UpperArm'] + pbone2 = armature.pose.bones['J_Bip_R_UpperArm'] + pbone1.rotation_mode = 'XYZ' + pbone2.rotation_mode = 'XYZ' + pbone1.rotation_euler.rotate_axis('X', math.radians(-45)) + pbone2.rotation_euler.rotate_axis('X', math.radians(-45)) + bpy.ops.object.mode_set(mode='OBJECT') + + def printInfo(): + print("====== Objects ======") + for i in bpy.data.objects: + print(i.name) + print("====== Materials ======") + for i in bpy.data.materials: + print(i.name) + + def parse_material(): + hair_mats = [] + cloth_mats = [] + face_mats = [] + body_mats = [] + + # main hair material + if 'Hair' in bpy.data.objects: + hair_mats = [i.name for i in bpy.data.objects['Hair'].data.materials if 'MToon Outline' not in i.name] + else: + flag = False + for i in bpy.data.objects: + if i.name[:4] == 'Hair' and bpy.data.objects[i.name].data: + hair_mats += [i.name for i in bpy.data.objects[i.name].data.materials if 'MToon Outline' not in i.name] + flag = True + if not flag: + if 'Hairs' in bpy.data.objects and bpy.data.objects['Hairs'].data: + hair_mats = [i.name for i in bpy.data.objects['Hairs'].data.materials if 'MToon Outline' not in i.name] + else: + for i in bpy.data.materials: + if 'HAIR' in i.name and 'MToon Outline' not in i.name: + hair_mats.append(i.name) + if len(hair_mats) == 0: + printInfo() + with open('error.txt', 'a+') as f: + f.write(object_file + '\t' + 'Cannot find main hair material\t' + 
str([iii.name for iii in bpy.data.objects]) + '\n') + raise ValueError("Cannot find main hair material") + + # face material + if 'Face' in bpy.data.objects: + face_mats = [i.name for i in bpy.data.objects['Face'].data.materials if 'MToon Outline' not in i.name] + else: + for i in bpy.data.materials: + if 'FACE' in i.name and 'MToon Outline' not in i.name: + face_mats.append(i.name) + elif 'Face' in i.name and 'SKIN' in i.name and 'MToon Outline' not in i.name: + face_mats.append(i.name) + if len(face_mats) == 0: + printInfo() + with open('error.txt', 'a+') as f: + f.write(object_file + '\t' + 'Cannot find face material\t' + str([iii.name for iii in bpy.data.objects]) + '\n') + raise ValueError("Cannot find face material") + + # loop + for i in bpy.data.materials: + if 'MToon Outline' in i.name: + continue + elif 'CLOTH' in i.name: + if 'Shoes' in i.name: + body_mats.append(i.name) + elif 'Accessory' in i.name: + if 'CatEar' in i.name: + hair_mats.append(i.name) + else: + cloth_mats.append(i.name) + elif any( name in i.name for name in ['Tops', 'Bottoms', 'Onepice'] ): + cloth_mats.append(i.name) + else: + raise ValueError(f"Unknown cloth material: {i.name}") + elif 'Body' in i.name and 'SKIN' in i.name: + body_mats.append(i.name) + elif i.name in hair_mats or i.name in face_mats: + continue + elif 'HairBack' in i.name and 'HAIR' in i.name: + hair_mats.append(i.name) + elif 'EYE' in i.name: + face_mats.append(i.name) + elif 'Face' in i.name and 'SKIN' in i.name: + face_mats.append(i.name) + else: + print("hair_mats", hair_mats) + print("cloth_mats", cloth_mats) + print("face_mats", face_mats) + print("body_mats", body_mats) + with open('error.txt', 'a+') as f: + f.write(object_file + '\t' + 'Cannot find material\t' + i.name + '\n') + raise ValueError(f"Unknown material: {i.name}") + + return hair_mats, cloth_mats, face_mats, body_mats + + hair_mats, cloth_mats, face_mats, body_mats = parse_material() + + # get bounding box of face + def get_face_bbox(): + if 'Face' in bpy.data.objects: + face = bpy.data.objects['Face'] + bbox_min, bbox_max = scene_bbox(face) + return bbox_min, bbox_max + else: + bbox_min, bbox_max = scene_bbox() + for i in bpy.data.objects: + if i.data.materials and i.data.materials[0].name in face_mats: + face = i + cur_bbox_min, cur_bbox_max = scene_bbox(face) + bbox_min = np.minimum(bbox_min, cur_bbox_min) + bbox_max = np.maximum(bbox_max, cur_bbox_max) + return bbox_min, bbox_max + + def assign_color(material_name, color): + material = bpy.data.materials.get(material_name) + if material: + material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1) + image = material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_texture.index.source + if image: + pixels = np.array(image.pixels[:]) + width, height = image.size + num_channels = 4 + pixels = pixels.reshape((height, width, num_channels)) + srgb_pixels = np.clip(np.power(pixels, 1/2.2), 0.0, 1.0) + print("Image converted to NumPy array") + + # Step 2: Edit the NumPy array + srgb_pixels[..., 0] = color[0] + srgb_pixels[..., 1] = color[1] + srgb_pixels[..., 2] = color[2] + edited_image_rgba = srgb_pixels + + # Step 3: Convert the edited NumPy array back to a Blender image + edited_image_flat = edited_image_rgba.astype(np.float32) + edited_image_flat = edited_image_flat.flatten() + edited_image_name = "Edited_Texture" + edited_blender_image = bpy.data.images.new(edited_image_name, width, height, alpha=True) + edited_blender_image.pixels = edited_image_flat + 
material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_texture.index.source = edited_blender_image + print(f"Edited image assigned to {material_name}") + + material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_color_factor = (1, 1, 1) + image = material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_multiply_texture.index.source + if image: + pixels = np.array(image.pixels[:]) + width, height = image.size + num_channels = 4 + pixels = pixels.reshape((height, width, num_channels)) + srgb_pixels = np.clip(np.power(pixels, 1/2.2), 0.0, 1.0) + print("Image converted to NumPy array") + + # Step 2: Edit the NumPy array + srgb_pixels[..., 0] = color[0] + srgb_pixels[..., 1] = color[1] + srgb_pixels[..., 2] = color[2] + edited_image_rgba = srgb_pixels + + # Step 3: Convert the edited NumPy array back to a Blender image + edited_image_flat = edited_image_rgba.astype(np.float32) + edited_image_flat = edited_image_flat.flatten() + edited_image_name = "Edited_Texture" + edited_blender_image = bpy.data.images.new(edited_image_name, width, height, alpha=True) + edited_blender_image.pixels = edited_image_flat + material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_multiply_texture.index.source = edited_blender_image + print(f"Edited image assigned to {material_name}") + material.vrm_addon_extension.mtoon1.extensions.khr_materials_emissive_strength.emissive_strength = 0.0 + + def assign_transparency(material_name, alpha): + material = bpy.data.materials.get(material_name) + if material: + material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, alpha) + + # render the images + use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH" + + face_bbox_min, face_bbox_max = get_face_bbox() + face_bbox_center = (face_bbox_min + face_bbox_max) / 2 + face_bbox_size = face_bbox_max - face_bbox_min + print("face_bbox_center", face_bbox_center) + print("face_bbox_size", face_bbox_size) + + config_names = ["custom2", "custom_top", "custom_bottom", "custom_face", "random"] + + # normal rendering + for l in range(3): # 3 levels: all; no hair; no hair and no cloth + if l == 0: + pass + elif l == 1: + for i in hair_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0) + elif l == 2: + for i in cloth_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0) + + for j in range(5): # 5 track + config = configs[config_names[j]] + if "render_num" in config: + new_num_renders = config["render_num"] + else: + new_num_renders = num_renders + + for i in range(new_num_renders): + camera_dist = 1.4 + if config_names[j] == "custom_face": + camera_dist = 0.6 + if i not in [0, 1, 2, 6, 7]: + continue + t = i / num_renders + elevation_range = config["elevation_range"] + init_elevation = elevation_range[0] + # set camera + camera = place_camera( + t, + camera_pose_mode=config["camera_pose"], + camera_dist=camera_dist, + rotate=config["rotate"], + elevation=init_elevation, + camera_offset=face_bbox_center if config_names[j] == "custom_face" else 0.0, + idx=i + ) + + # set camera to ortho + bpy.data.objects["Camera"].data.type = 'ORTHO' + bpy.data.objects["Camera"].data.ortho_scale = 1.2 if config_names[j] != "custom_face" else np.max(face_bbox_size) * 1.2 + + # render the image + render_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}.png") + scene.render.filepath = render_path + 
setup_nodes(render_path) + bpy.ops.render.render(write_still=True) + + # save camera RT matrix + rt_matrix = get_3x4_RT_matrix_from_blender(camera) + rt_matrix_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}.npy") + np.save(rt_matrix_path, rt_matrix) + + for channel_name in ["r", "g", "b", "a", "depth", "normal"]: + sub_dir = f"{render_path}_{channel_name}" + if channel_name in ['r', 'g', 'b']: + # remove path + shutil.rmtree(sub_dir) + continue + + image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0]) + name, ext = os.path.splitext(render_path) + if channel_name == "a": + os.rename(image_path, f"{name}_{channel_name}.png") + elif channel_name == 'depth': + os.rename(image_path, f"{name}_{channel_name}.exr") + elif channel_name == "normal": + os.rename(image_path, f"{name}_{channel_name}.exr") + else: + os.remove(image_path) + + os.removedirs(sub_dir) + + # reset + for i in hair_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1) + for i in cloth_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1) + + # switch to semantic rendering + for i in hair_mats: + assign_color(i, [1.0, 0.0, 0.0]) + for i in cloth_mats: + assign_color(i, [0.0, 0.0, 1.0]) + for i in face_mats: + assign_color(i, [0.0, 1.0, 1.0]) + if any( ii in i for ii in ['Eyeline', 'Eyelash', 'Brow', 'Highlight'] ): + assign_transparency(i, 0.0) + for i in body_mats: + assign_color(i, [0.0, 1.0, 0.0]) + + for l in range(3): # 3 levels: all; no hair; no hair and no cloth + if l == 0: + pass + elif l == 1: + for i in hair_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0) + elif l == 2: + for i in cloth_mats: + bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0) + for j in range(5): # 5 track + config = configs[config_names[j]] + if "render_num" in config: + new_num_renders = config["render_num"] + else: + new_num_renders = num_renders + + for i in range(new_num_renders): + camera_dist = 1.4 + if config_names[j] == "custom_face": + camera_dist = 0.6 + if i not in [0, 1, 2, 6, 7]: + continue + t = i / num_renders + elevation_range = config["elevation_range"] + init_elevation = elevation_range[0] + # set camera + camera = place_camera( + t, + camera_pose_mode=config["camera_pose"], + camera_dist=camera_dist, + rotate=config["rotate"], + elevation=init_elevation, + camera_offset=face_bbox_center if config_names[j] == "custom_face" else 0.0, + idx=i + ) + + # set camera to ortho + bpy.data.objects["Camera"].data.type = 'ORTHO' + bpy.data.objects["Camera"].data.ortho_scale = 1.2 if config_names[j] != "custom_face" else np.max(face_bbox_size) * 1.2 + + # render the image + render_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}_semantic.png") + scene.render.filepath = render_path + setup_nodes_semantic(render_path) + bpy.ops.render.render(write_still=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--object_path", + type=str, + required=True, + help="Path to the object file", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to the directory where the rendered images and metadata will be saved.", + ) + parser.add_argument( + "--engine", + type=str, + default="BLENDER_EEVEE", + choices=["CYCLES", "BLENDER_EEVEE"], + ) + parser.add_argument( + "--only_northern_hemisphere", + 
action="store_true", + help="Only render the northern hemisphere of the object.", + default=False, + ) + parser.add_argument( + "--num_renders", + type=int, + default=8, + help="Number of renders to save of the object.", + ) + argv = sys.argv[sys.argv.index("--") + 1 :] + args = parser.parse_args(argv) + + context = bpy.context + scene = context.scene + render = scene.render + + # Set render settings + render.engine = args.engine + render.image_settings.file_format = "PNG" + render.image_settings.color_mode = "RGB" + render.resolution_x = 1024 + render.resolution_y = 1024 + render.resolution_percentage = 100 + + # Set EEVEE settings + scene.eevee.taa_render_samples = 64 + scene.eevee.use_taa_reprojection = True + + # Set cycles settings + scene.cycles.device = "GPU" + scene.cycles.samples = 128 + scene.cycles.diffuse_bounces = 9 + scene.cycles.glossy_bounces = 9 + scene.cycles.transparent_max_bounces = 9 + scene.cycles.transmission_bounces = 9 + scene.cycles.filter_width = 0.01 + scene.cycles.use_denoising = True + scene.render.film_transparent = True + bpy.context.preferences.addons["cycles"].preferences.get_devices() + bpy.context.preferences.addons[ + "cycles" + ].preferences.compute_device_type = "CUDA" # or "OPENCL" + bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True + + bpy.context.view_layer.use_pass_normal = True + render.image_settings.color_depth = "16" + bpy.context.scene.use_nodes = True + + # Render the images + render_object( + object_file=args.object_path, + num_renders=args.num_renders, + only_northern_hemisphere=args.only_northern_hemisphere, + output_dir=args.output_dir, + ) diff --git a/blender/distributed_uniform_lrm.py b/blender/distributed_uniform_lrm.py new file mode 100755 index 0000000000000000000000000000000000000000..e8719d854a333041b50f4309e0c5149ac27dbf9c --- /dev/null +++ b/blender/distributed_uniform_lrm.py @@ -0,0 +1,122 @@ +import json +import multiprocessing +import subprocess +import time +from dataclasses import dataclass +import os +import tyro +import concurrent.futures +@dataclass +class Args: + workers_per_gpu: int + """number of workers per gpu""" + num_gpus: int = 8 + """number of gpus to use. 
-1 means all available gpus""" + input_dir: str + save_dir: str + engine: str = "BLENDER_EEVEE" + + +def check_already_rendered(save_path): + if not os.path.exists(os.path.join(save_path, '02419_semantic.png')): + return False + return True + +def process_file(file): + if not check_already_rendered(file[1]): + return file + return None + +def worker(queue, count, gpu): + while True: + try: + item = queue.get() + if item is None: + queue.task_done() + break + data_path, save_path, engine, log_name = item + print(f"Processing: {data_path} on GPU {gpu}") + start = time.time() + if check_already_rendered(save_path): + queue.task_done() + print('========', item, 'rendered', '========') + continue + else: + os.makedirs(save_path, exist_ok=True) + command = (f"export DISPLAY=:0.{gpu} &&" + f" CUDA_VISIBLE_DEVICES={gpu} " + f" blender -b -P blender_lrm_script.py --" + f" --object_path {data_path} --output_dir {save_path} --engine {engine}") + + try: + subprocess.run(command, shell=True, timeout=3600, check=True) + count.value += 1 + end = time.time() + with open(log_name, 'a') as f: + f.write(f'{end - start}\n') + except subprocess.CalledProcessError as e: + print(f"Subprocess error processing {item}: {e}") + except subprocess.TimeoutExpired as e: + print(f"Timeout expired processing {item}: {e}") + except Exception as e: + print(f"Error processing {item}: {e}") + finally: + queue.task_done() + + except Exception as e: + print(f"Error processing {item}: {e}") + queue.task_done() + + +if __name__ == "__main__": + args = tyro.cli(Args) + queue = multiprocessing.JoinableQueue() + count = multiprocessing.Value("i", 0) + log_name = f'time_log_{args.workers_per_gpu}_{args.num_gpus}_{args.engine}.txt' + + if args.num_gpus == -1: + result = subprocess.run(['nvidia-smi', '--list-gpus'], stdout=subprocess.PIPE) + output = result.stdout.decode('utf-8') + args.num_gpus = output.count('GPU') + + files = [] + + for group in [ str(i) for i in range(10) ]: + for folder in os.listdir(f'{args.input_dir}/{group}'): + filename = f'{args.input_dir}/{group}/{folder}/{folder}.vrm' + outputdir = f'{args.save_dir}/{group}/{folder}' + files.append([filename, outputdir]) + + # sorted the files + files = sorted(files, key=lambda x: x[0]) + + # Use ThreadPoolExecutor for parallel processing + with concurrent.futures.ThreadPoolExecutor() as executor: + # Map the process_file function to the files + results = list(executor.map(process_file, files)) + + # Filter out None values from the results + unprocess_files = [file for file in results if file is not None] + + # Print the number of unprocessed files and the split ID + print(f'Unprocessed files: {len(unprocess_files)}') + + # Start worker processes on each of the GPUs + for gpu_i in range(args.num_gpus): + for worker_i in range(args.workers_per_gpu): + worker_i = gpu_i * args.workers_per_gpu + worker_i + process = multiprocessing.Process( + target=worker, args=(queue, count, gpu_i) + ) + process.daemon = True + process.start() + + for file in unprocess_files: + queue.put((file[0], file[1], args.engine, log_name)) + + # Add sentinels to the queue to stop the worker processes + for i in range(args.num_gpus * args.workers_per_gpu * 10): + queue.put(None) + # Wait for all tasks to be completed + queue.join() + end = time.time() diff --git a/blender/install_addon.py b/blender/install_addon.py new file mode 100755 index 0000000000000000000000000000000000000000..245f7db99ea4f431fc30b09955d74371d7654c64 --- /dev/null +++ b/blender/install_addon.py @@ -0,0 +1,15 @@ +import bpy +import 
sys + +def install_addon(addon_path): + bpy.ops.preferences.addon_install(filepath=addon_path) + bpy.ops.preferences.addon_enable(module=addon_path.split('/')[-1].replace('.py', '').replace('.zip', '')) + bpy.ops.wm.save_userpref() + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: blender --background --python install_addon.py -- ") + sys.exit(1) + + addon_path = sys.argv[-1] + install_addon(addon_path) diff --git a/canonicalize/__init__.py b/canonicalize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/canonicalize/models/attention.py b/canonicalize/models/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..e4be16c609d63b2d2f9541e124c03cb572dfc736 --- /dev/null +++ b/canonicalize/models/attention.py @@ -0,0 +1,344 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm + +from einops import rearrange, repeat + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_attn_temp: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_attn_temp = use_attn_temp, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_attn_temp: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_attn_temp = use_attn_temp + # SC-Attn + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + if 
self.use_attn_temp: + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + #self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.use_attn_temp: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + # query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if 
self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + if not use_full_attn: + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + else: + # key = torch.cat([key[:, [0] * video_length], key[:, [1] * video_length], key[:, [2] * video_length], key[:, [3] * video_length]], dim=2) + key_video_length = [key[:, [i] * video_length] for i in range(video_length)] + key = torch.cat(key_video_length, dim=2) + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + if not use_full_attn: + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + else: + # value = torch.cat([value[:, [0] * video_length], value[:, [1] * video_length], value[:, [2] * video_length], value[:, [3] * video_length]], dim=2) + value_video_length = [value[:, [i] * video_length] for i in range(video_length)] + value = torch.cat(value_video_length, dim=2) + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states diff --git a/canonicalize/models/imageproj.py b/canonicalize/models/imageproj.py new file mode 100755 index 0000000000000000000000000000000000000000..63e20527154594ef7a207b81c6520af2b07b8e50 --- /dev/null +++ b/canonicalize/models/imageproj.py @@ -0,0 +1,118 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, 
dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/canonicalize/models/refunet.py b/canonicalize/models/refunet.py new file mode 100755 index 0000000000000000000000000000000000000000..361bf0fe75be622d08cc1c45e005054056ed1f9b --- /dev/null +++ b/canonicalize/models/refunet.py @@ -0,0 +1,127 @@ +import torch +from einops import rearrange +from typing import Any, Dict, Optional +from diffusers.utils.import_utils import is_xformers_available +from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor + + +class ReferenceOnlyAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, + mode="w", ref_dict: dict = None, is_cfg_guidance = False,num_views=4, + multiview_attention=True, + cross_domain_attention=False, + ) -> Any: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if self.enabled: + if mode == 'w': + ref_dict[self.name] = encoder_hidden_states + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=1, + multiview_attention=False, + cross_domain_attention=False,) + elif mode == 'r': + encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views) + if 
self.name in ref_dict: + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views, + multiview_attention=False, + cross_domain_attention=False,) + elif mode == 'm': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) + elif mode == 'n': + encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views) + encoder_hidden_states = torch.cat([encoder_hidden_states], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views, + multiview_attention=False, + cross_domain_attention=False,) + else: + assert False, mode + else: + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + return res + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet, train_sched, val_sched) -> None: + super().__init__() + self.unet = unet + self.train_sched = train_sched + self.val_sched = val_sched + + unet_lora_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + if is_xformers_available(): + default_attn_proc = XFormersMVAttnProcessor() + else: + default_attn_proc = MVAttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name) + + self.unet.set_attn_processor(unet_lora_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): + if is_cfg_guidance: + encoder_hidden_states = encoder_hidden_states[1:] + class_labels = class_labels[1:] + self.unet( + noisy_cond_lat, timestep, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), + **kwargs + ) + + def forward( + self, sample, timestep, encoder_hidden_states, class_labels=None, + *args, cross_attention_kwargs, + down_block_res_samples=None, mid_block_res_sample=None, + **kwargs + ): + cond_lat = cross_attention_kwargs['cond_lat'] + is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) + noise = torch.randn_like(cond_lat) + if self.training: + noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) + noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) + else: + noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + ref_dict = {} + self.forward_cond( + noisy_cond_lat, timestep, + encoder_hidden_states, class_labels, + ref_dict, is_cfg_guidance, **kwargs + ) + weight_dtype = self.unet.dtype + return self.unet( + sample, timestep, + encoder_hidden_states, *args, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=( + mid_block_res_sample.to(dtype=weight_dtype) + if mid_block_res_sample is not None else None + ), + **kwargs + ) \ No newline at end of 
file diff --git a/canonicalize/models/resnet.py b/canonicalize/models/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..dece4326514c4da95868ae4838bc48383f2bbb7b --- /dev/null +++ b/canonicalize/models/resnet.py @@ -0,0 +1,209 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + 
time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0,2,1,3,4) + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/canonicalize/models/transformer_mv2d.py b/canonicalize/models/transformer_mv2d.py new file mode 100755 index 0000000000000000000000000000000000000000..56eaf39375c83d7e686de4dc1b9baf301bf77550 --- /dev/null +++ b/canonicalize/models/transformer_mv2d.py @@ -0,0 +1,976 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +try: + from diffusers.utils import maybe_allow_in_graph +except: + from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import pdb +import random + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 
+ attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + joint_attention: bool=False, + joint_attention_twice: bool=False, + multiview_attention: bool=True, + cross_domain_attention: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. 
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
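+        num_views (`int`, *optional*, defaults to 1): The number of views packed into the batch dimension; the
+            multi-view self-attention assumes inputs of shape `(batch * num_views, tokens, dim)`.
+        joint_attention (`bool`, *optional*, defaults to `False`): Whether to add a zero-initialized joint
+            attention layer (`attn_joint`) that is applied after the feed-forward step.
+        joint_attention_twice (`bool`, *optional*, defaults to `False`): Whether to add a second joint attention
+            layer (`attn_joint_twice`) applied right after the self-attention step.
+        multiview_attention (`bool`, *optional*, defaults to `True`): Whether the self-attention keys and values
+            are shared across all views.
+        cross_domain_attention (`bool`, *optional*, defaults to `False`): Whether the two halves of the batch
+            additionally exchange keys and values in the xformers attention processor.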
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.cross_domain_attention = cross_domain_attention + self.attn1 = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + self.joint_attention = joint_attention + + if self.joint_attention: + # Joint task -Attn + self.attn_joint = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint.to_out[0].weight.data) + self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + + self.joint_attention_twice = joint_attention_twice + + if self.joint_attention_twice: + print("joint twice") + # Joint task -Attn + self.attn_joint_twice = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data) + self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views, + multiview_attention=self.multiview_attention, + cross_domain_attention=self.cross_domain_attention, + **cross_attention_kwargs, + ) + + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # joint attention twice + if self.joint_attention_twice: + norm_hidden_states = ( + self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states) + ) + hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states + + # 2. 
Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + if self.joint_attention: + norm_hidden_states = ( + self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states) + ) + hidden_states = self.attn_joint(norm_hidden_states) + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
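+
+    Multi-view variant of the default attention processor: when `multiview_attention` is enabled, the keys and
+    values of all `num_views` views in the batch are concatenated so that every view attends to the tokens of
+    every other view. Attention is computed with the standard `attn.get_attention_scores` / `torch.bmm` path,
+    and inputs are expected to be packed as `(batch * num_views, tokens, dim)`.
+
+    Illustrative usage (a minimal sketch; assumes 4 views packed into the batch dimension and `query_dim=320`):
+
+        attn = CustomAttention(query_dim=320, heads=8, dim_head=40, processor=MVAttnProcessor())
+        hidden_states = torch.randn(2 * 4, 64, 320)  # (batch * num_views, tokens, dim)
+        out = attn(hidden_states, num_views=4, multiview_attention=True)  # same shape as the input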
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # multi-view self-attention + if multiview_attention: + if num_views <= 6: + # after use xformer; possible to train with 6 views + # key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + # value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + key = rearrange(key, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + value = rearrange(value, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + + else:# apply sparse attention + raise NotImplementedError("Sparse attention is not implemented yet.") + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
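+
+    Multi-view attention processor backed by `xformers.ops.memory_efficient_attention`. As in `MVAttnProcessor`,
+    the keys and values of all `num_views` views are shared so that every view attends to every view. When
+    `cross_domain_attention` is enabled, the keys and values of the opposite half of the batch are appended to
+    each sample's own, so each domain also attends to the other domain.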
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1., + multiview_attention=True, + cross_domain_attention=False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # multi-view self-attention + if multiview_attention: + key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + + if cross_domain_attention: + # memory efficient, cross domain attention + key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) + key_cross = torch.concat([key_1, key_0], dim=0) + value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c + key = torch.cat([key, key_cross], dim=1) + value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c + else: + # print("don't use multiview attention.") + key = key_raw + value = value_raw + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
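+
+    Joint-task attention processor backed by `xformers.ops.memory_efficient_attention`. The batch is assumed to
+    contain two tasks stacked along the batch dimension (`num_tasks == 2`); their keys and values are concatenated
+    along the token dimension and shared by both halves, so each task attends to a joint key/value set.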
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
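+
+    Same joint two-task key/value sharing as `XFormersJointAttnProcessor`, but computed with the standard
+    `attn.get_attention_scores` / `torch.bmm` attention path instead of xformers.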
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/canonicalize/models/unet.py b/canonicalize/models/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..594482d4e666d6ac937b3e7d3fe65784b37f708e --- /dev/null +++ b/canonicalize/models/unet.py @@ -0,0 +1,475 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + 
@register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_attn_temp: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp + ) + 
self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + camera_matrixs: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) #torch.Size([32, 1280]) + emb = torch.unsqueeze(emb, 1) + if camera_matrixs is not None: + cam_emb = self.camera_embedding(camera_matrixs) + emb = emb.repeat(1,cam_emb.shape[1],1) + emb = emb + cam_emb + + if self.class_embedding is not None: + if class_labels is not None: + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, 
res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME + + import safetensors + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + else: + state_dict = safetensors.torch.load_file(model_file, device="cpu") + else: + state_dict = torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k: + state_dict.update({k: v}) + for k in list(state_dict.keys()): + if 'camera_embedding_' in k: + v = state_dict.pop(k) + model.load_state_dict(state_dict) + + return model \ No newline at end of file diff --git a/canonicalize/models/unet_blocks.py b/canonicalize/models/unet_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..ac9b78d42fbd35261c23775e54cfa88f6507925d --- /dev/null +++ b/canonicalize/models/unet_blocks.py @@ -0,0 +1,596 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +# from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_attn_temp=False, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + 
out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_attn_temp=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if 
dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_attn_temp=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_attn_temp=use_attn_temp, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, 
return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_attn_temp=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + 
ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_attn_temp=use_attn_temp, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + 
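When `gradient_checkpointing` is enabled, the forward passes of these 3D blocks wrap each resnet (and attention) module in `create_custom_forward` and run it through `torch.utils.checkpoint.checkpoint`, so activations are recomputed during the backward pass instead of being stored. A minimal standalone sketch of that pattern; the toy module and sizes are assumptions for illustration:

```python
# Illustrative sketch of the gradient-checkpointing pattern used in the blocks above.
import torch
import torch.nn as nn

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = nn.Sequential(nn.Linear(16, 16), nn.SiLU(), nn.Linear(16, 16))  # stand-in for a ResnetBlock3D
x = torch.randn(4, 16, requires_grad=True)

# use_reentrant=False mirrors the ckpt_kwargs used elsewhere in this patch (torch >= 1.11);
# intermediate activations inside `block` are recomputed on backward instead of cached.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)   # torch.Size([4, 16])
```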
self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/canonicalize/models/unet_mv2d_blocks.py b/canonicalize/models/unet_mv2d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..114ee2649231e73e1137f37c474a0e73dc000a55 --- /dev/null +++ b/canonicalize/models/unet_mv2d_blocks.py @@ -0,0 +1,924 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +# from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from canonicalize.models.transformer_mv2d import TransformerMV2DModel + +from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D +from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + 
attention_head_dim=None, + downsample_type=None, + num_views=1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif down_block_type == "CrossAttnDownBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") + return CrossAttnDownBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + 
num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + 
up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + num_views=1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif up_block_type == "CrossAttnUpBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") + return CrossAttnUpBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + 
use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == 
"KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockMV2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + ) + else: + raise NotImplementedError + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return 
hidden_states + + +class CrossAttnUpBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + if num_views == 4: + self.gradient_checkpointing = False + else: + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if 
is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + # hidden_states = attn( + # hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # cross_attention_kwargs=cross_attention_kwargs, + # attention_mask=attention_mask, + # encoder_attention_mask=encoder_attention_mask, + # return_dict=False, + # )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnDownBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + if num_views == 4: + self.gradient_checkpointing 
= False + else: + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + diff --git a/canonicalize/models/unet_mv2d_condition.py b/canonicalize/models/unet_mv2d_condition.py new file mode 100755 index 0000000000000000000000000000000000000000..ce21ab09a76b04c681063bfda02a65d50ed0d9aa --- /dev/null +++ b/canonicalize/models/unet_mv2d_condition.py @@ -0,0 +1,1502 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
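+
+# This module defines `UNetMV2DConditionModel`, a multi-view variant of diffusers'
+# `UNet2DConditionModel`: the standard cross-attention blocks are replaced by the
+# CrossAttnDownBlockMV2D / UNetMidBlockMV2DCrossAttn / CrossAttnUpBlockMV2D blocks
+# from `unet_mv2d_blocks.py`, and per-view camera matrices are embedded and added to
+# the timestep embedding so that every view receives its own conditioning vector.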
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange + + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) +from diffusers import __version__ +from canonicalize.models.unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) +from diffusers.models.attention_processor import Attention, AttnProcessor +from diffusers.utils.import_utils import is_xformers_available +from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor +from canonicalize.models.refunet import ReferenceOnlyAttnProc + +from huggingface_hub.constants import HF_HUB_CACHE +from diffusers.utils.hub_utils import HF_HUB_OFFLINE + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + +class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. 
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. 
+ time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] 
= None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 
+ ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
+ ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + 
attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
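+
+        A minimal usage sketch (assumes `unet` is an already-instantiated
+        `UNetMV2DConditionModel`; the exact key names depend on the block layout):
+
+        ```py
+        procs = unet.attn_processors
+        # keys follow the module path of each attention layer, e.g.
+        # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
+        unet.set_attn_processor(procs)  # round-trips the current processors
+        ```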
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
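+
+        A minimal sketch (assumes `unet` is an already-instantiated `UNetMV2DConditionModel`):
+
+        ```py
+        unet.set_attention_slice("auto")  # compute attention in two steps per layer
+        unet.set_attention_slice("max")   # run one slice at a time for lowest memory
+        ```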
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + camera_matrixs: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetMV2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
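+        # When `camera_matrixs` is provided, the (batch, dim) time embedding computed below is
+        # broadcast across the view axis, summed with the camera embedding, and flattened to
+        # (batch * num_views, dim) so that each view carries its own conditioning.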
+ t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if camera_matrixs is not None: + emb = torch.unsqueeze(emb, 1) + cam_emb = self.camera_embedding(camera_matrixs) + emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280]) + emb = emb + cam_emb + emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1]) + + aug_emb = None + + if self.class_embedding is not None and class_labels is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if 
aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2]) + sample = self.conv_in(sample) + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNetMV2DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_class_embeddings_input_dim: int=6, joint_attention: bool = False, + joint_attention_twice: bool = False, multiview_attention: bool = True, + cross_domain_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, local_crossattn=False, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 
+ output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. 
You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + # if use_safetensors and not is_safetensors_available(): + # raise ValueError( + # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + # ) + + allow_pickle = False + if use_safetensors is None: + # use_safetensors = is_safetensors_available() + use_safetensors = False + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." 
+ ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['joint_attention'] = joint_attention + config['joint_attention_twice'] = joint_attention_twice + config['multiview_attention'] = multiview_attention + config['cross_domain_attention'] = cross_domain_attention + config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + config['class_embed_type'] = 'projection' + if camera_embedding_type == 'e_de_da_sincos': + config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 + else: + raise NotImplementedError + + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + if local_crossattn: + unet_lora_attn_procs = dict() + for name, _ in model.attn_processors.items(): + if not name.endswith("attn1.processor"): + default_attn_proc = AttnProcessor() + elif is_xformers_available(): + default_attn_proc = XFormersMVAttnProcessor() + else: + default_attn_proc = MVAttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + model.set_attn_processor(unet_lora_attn_procs) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, 
mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if zero_init_camera_projection: + for p in model.class_embedding.parameters(): + torch.nn.init.zeros_(p) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/canonicalize/models/unet_mv2d_ref.py b/canonicalize/models/unet_mv2d_ref.py new file mode 100755 index 0000000000000000000000000000000000000000..c94c09ed1f1d11c22f68b0a760dbeb9cfd0b9e4f --- /dev/null +++ b/canonicalize/models/unet_mv2d_ref.py @@ -0,0 +1,1543 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
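+
+# Editor's note (non-authoritative): this module defines `UNetMV2DRefModel`, a multi-view
+# variant of diffusers' `UNet2DConditionModel` used with reference-only attention in the
+# canonicalization stage. Compared to a plain conditional UNet it adds a camera-parameter
+# embedding MLP, the MV2D cross-attention down/mid/up blocks, and a `from_pretrained_2d`
+# classmethod that loads and adapts pretrained 2D UNet weights. A hedged usage sketch is
+# given in the class docstring below.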
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange + + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.lora import LoRALinearLayer + +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) +from diffusers import __version__ +from canonicalize.models.unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) +from diffusers.models.attention_processor import Attention, AttnProcessor +from diffusers.utils.import_utils import is_xformers_available +from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor +from canonicalize.models.refunet import ReferenceOnlyAttnProc + +from huggingface_hub.constants import HF_HUB_CACHE +from diffusers.utils.hub_utils import HF_HUB_OFFLINE + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DRefOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + +class Identity(torch.nn.Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + def __init__(self, scale=None, *args, **kwargs) -> None: + super(Identity, self).__init__() + + def forward(self, input, *args, **kwargs): + return input + + + +class _LoRACompatibleLinear(nn.Module): + """ + A Linear layer that can be used with LoRA. 
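+
+    Note: despite the name, this variant is a pass-through stub: `forward` returns
+    `hidden_states` unchanged and the `_fuse_lora`/`_unfuse_lora` hooks are no-ops. At the
+    end of `UNetMV2DRefModel.__init__` it replaces the `to_q`/`to_k`/`to_v` projections of
+    the last up-block's final transformer block, which, together with the `Identity`
+    replacements applied there, effectively disables that block's attention and feed-forward.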
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self): + pass + + def _unfuse_lora(self): + pass + + def forward(self, hidden_states, scale=None, lora_scale: int = 1): + return hidden_states + +class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 
+ encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. 
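+
+    Example (a minimal sketch, not part of the original code; the checkpoint id and the
+    concrete argument values are illustrative assumptions, not values shipped with this
+    repository):
+
+    ```py
+    from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
+
+    unet_ref = UNetMV2DRefModel.from_pretrained_2d(
+        "runwayml/stable-diffusion-v1-5",    # any SD-style checkpoint or local directory
+        subfolder="unet",
+        camera_embedding_type="e_de_da_sincos",
+        num_views=4,
+        sample_size=64,
+        in_channels=8,
+        out_channels=4,
+    )
+    # returned in eval mode; switch back with unet_ref.train() before fine-tuning
+    ```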
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. 
+ num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 
+ ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
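+            # Editor's note: for the multi-view UNets in this patch, `from_pretrained_2d`
+            # selects this 'projection' branch (the companion `unet_mv2d_condition.py` sets
+            # `class_embed_type='projection'` with `projection_class_embeddings_input_dim`
+            # defaulting to 6 for the 'e_de_da_sincos' camera embedding), so `class_labels`
+            # is repurposed to carry camera parameters rather than discrete class ids.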
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + 
attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()]) + self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None + self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity() + self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity() + self.up_blocks[3].attentions[2].proj_out = Identity() + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all 
attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
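+
+        Example (illustrative; `unet` is any instance of this model):
+
+        ```py
+        unet.set_attention_slice("auto")  # or "max", an int, or a per-layer list of ints
+        ```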
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + camera_matrixs: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetMV2DRefOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if camera_matrixs is not None: + emb = torch.unsqueeze(emb, 1) + cam_emb = self.camera_embedding(camera_matrixs) + emb = emb.repeat(1,cam_emb.shape[1],1) + emb = emb + cam_emb + emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1]) + + aug_emb = None + + if self.class_embedding is not None and class_labels is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param 
`addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2]) + sample = self.conv_in(sample) + # 3. 
down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + if not return_dict: + return (sample,) + + return UNetMV2DRefOutput(sample=sample) + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_class_embeddings_input_dim: int=6, joint_attention: bool = False, + joint_attention_twice: bool = False, multiview_attention: bool = True, + cross_domain_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, local_crossattn=False, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. 
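+
+        For the multiview reference UNet defined in this file, a minimal call sketch is shown
+        below. It is illustrative only: the checkpoint path, the number of views and the other
+        values are assumptions rather than a verified configuration, and `MVRefUNet` is a
+        placeholder for this class.
+
+        ```py
+        # Hypothetical usage sketch -- values below are placeholders, not a tested config.
+        unet = MVRefUNet.from_pretrained_2d(
+            "runwayml/stable-diffusion-v1-5",           # assumed base SD checkpoint
+            subfolder="unet",
+            camera_embedding_type="e_de_da_sincos",     # the only type handled by this loader
+            num_views=6,                                # assumed number of views
+            sample_size=96,                             # assumed latent training resolution
+            projection_class_embeddings_input_dim=6,    # signature default
+            in_channels=8,                              # signature default
+            out_channels=4,                             # signature default
+            local_crossattn=True,
+        )
+        ```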
+ + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + # if use_safetensors and not is_safetensors_available(): + # raise ValueError( + # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + # ) + + allow_pickle = False + if use_safetensors is None: + # use_safetensors = is_safetensors_available() + use_safetensors = False + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." 
+ ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['joint_attention'] = joint_attention + config['joint_attention_twice'] = joint_attention_twice + config['multiview_attention'] = multiview_attention + config['cross_domain_attention'] = cross_domain_attention + config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + config['class_embed_type'] = 'projection' + if camera_embedding_type == 'e_de_da_sincos': + config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 + else: + raise NotImplementedError + + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + if local_crossattn: + unet_lora_attn_procs = dict() + for name, _ in model.attn_processors.items(): + if not name.endswith("attn1.processor"): + default_attn_proc = AttnProcessor() + elif is_xformers_available(): + default_attn_proc = XFormersMVAttnProcessor() + else: + default_attn_proc = MVAttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + model.set_attn_processor(unet_lora_attn_procs) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, 
mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if zero_init_camera_projection: + for p in model.class_embedding.parameters(): + torch.nn.init.zeros_(p) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." 
+ ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/canonicalize/pipeline_canonicalize.py b/canonicalize/pipeline_canonicalize.py new file mode 100755 index 0000000000000000000000000000000000000000..1c0351d5977693300aa371c1f82fbe52e79c5a88 --- /dev/null +++ b/canonicalize/pipeline_canonicalize.py @@ -0,0 +1,518 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py + +import tqdm + +import inspect +from typing import Callable, List, Optional, Union +from dataclasses import dataclass + +import numpy as np +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer +import torchvision.transforms.functional as TF + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange + +from canonicalize.models.unet import UNet3DConditionModel +from torchvision.transforms import InterpolationMode + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class CanonicalizationPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ref_unet = None, + feature_extractor=None, + image_encoder=None + ): + super().__init__() + self.ref_unet = ref_unet + self.feature_extractor = feature_extractor + self.image_encoder = image_encoder + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None): + dtype = next(self.image_encoder.parameters()).dtype + + # image encoding + clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device, dtype=torch.float32) + clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device, dtype=torch.float32) + imgs_in_proc = TF.resize(image_pil, (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC) + # do the normalization in float32 to 
preserve precision + imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype) + if img_proj is None: + # (B*Nv, 1, 768) + image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1) + # duplicate image embeddings for each generation per prompt, using mps friendly method + # Note: repeat differently from official pipelines + # B1B2B3B4 -> B1B2B3B4B1B2B3B4 + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1) + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + else: + if do_classifier_free_guidance: + negative_image_proc = torch.zeros_like(imgs_in_proc) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc]) + + image_embeds = image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2] + image_embeddings = img_proj(image_embeds) + + image_latents = self.vae.encode(image_pil* 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor + + # Note: repeat differently from official pipelines + # B1B2B3B4 -> B1B2B3B4B1B2B3B4 + image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1) + return image_embeddings, image_latents + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} 
!=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + video = self.vae.decode(latents).sample + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + class_labels = None, + prompt_ids = None, + unet_condition_type = None, + img_proj=None, + use_noise=True, + use_shifted_noise=False, + rescale = 0.7, + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + video_length = 1 + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + if isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + # Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input image + image_embeddings, image_latents = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj) #torch.Size([64, 1, 768]) torch.Size([64, 4, 32, 32]) + image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1) #torch.Size([64, 4, 1, 32, 32]) + + # Encode input prompt + text_embeddings = self._encode_prompt( #torch.Size([64, 77, 768]) + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(tqdm.tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_cond = torch.randn_like(image_latents) + if use_noise: + cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t) + else: + cond_latents = image_latents + cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents + cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t) + + # predict the noise residual + # ref text condition + ref_dict = {} + if self.ref_unet is not None: + noise_pred_cond = self.ref_unet( + cond_latent_model_input, + t, + encoder_hidden_states=text_embeddings.to(torch.float32), + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict) + ).sample.to(dtype=latents_dtype) + + # text condition for unet + text_embeddings_unet = text_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1,1) + text_embeddings_unet = rearrange(text_embeddings_unet, 'B Nv d c -> (B Nv) d c') + # image condition for unet + image_embeddings_unet = image_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1, 1) + image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c') + + encoder_hidden_states_unet_cond = image_embeddings_unet + + if self.ref_unet is not None: + noise_pred = self.unet( + latent_model_input.to(torch.float32), + t, + encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32), + cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance) + ).sample.to(dtype=latents_dtype) + else: + noise_pred = self.unet( + latent_model_input.to(torch.float32), + t, + encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32), + cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance) + ).sample.to(dtype=latents_dtype) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + if use_shifted_noise: + # Apply regular classifier-free guidance. 
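+                        # Note: the block below additionally rescales the guided prediction so its
+                        # per-sample std is blended back toward the std of the conditional branch,
+                        # controlled by `rescale` (default 0.7 in this pipeline). This "guidance
+                        # rescale" trick (cf. Lin et al., "Common Diffusion Noise Schedules and
+                        # Sample Steps Are Flawed", 2023) counteracts the over-saturation that
+                        # plain CFG can cause with a shifted / zero-terminal-SNR noise schedule.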
+ cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Calculate standard deviations. + std_pos = noise_pred_text.std([1,2,3], keepdim=True) + std_cfg = cfg.std([1,2,3], keepdim=True) + # Apply guidance rescale with fused operations. + factor = std_pos / std_cfg + factor = rescale * factor + (1 - rescale) + noise_pred = cfg * factor + else: + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length) + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + return video diff --git a/canonicalize/util.py b/canonicalize/util.py new file mode 100755 index 0000000000000000000000000000000000000000..3c4daa393169f319084632b4a1f172d8ba981bfc --- /dev/null +++ b/canonicalize/util.py @@ -0,0 +1,128 @@ +import os +import imageio +import numpy as np +from typing import Union +import cv2 +import torch +import torchvision + +from tqdm import tqdm +from einops import rearrange + +def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True): + alphas = 1 - betas + alphas_bar = torch.cumprod(alphas, dim=0) + d = (image_d / noise_d) ** 2 + if shifted_noise: + alphas_bar = alphas_bar / (d - (d - 1) * alphas_bar) + alphas_bar_sqrt = torch.sqrt(alphas_bar) + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
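+    # Together, the shift above and the rescale below apply an affine map to sqrt(alphas_bar)
+    # so that its value at the final timestep becomes exactly 0 (zero terminal SNR) while its
+    # value at the first timestep is preserved (cf. the zero-terminal-SNR rescaling of
+    # Lin et al., 2023). The earlier division alphas_bar / (d - (d - 1) * alphas_bar) is
+    # equivalent to dividing the SNR alphas_bar / (1 - alphas_bar) by d = (image_d / noise_d)**2,
+    # i.e. it adds the extra noise appropriate for the larger image resolution.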
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / ( + alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, duration=1000/fps) + +def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + for i, x in enumerate(videos): + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + os.makedirs(os.path.dirname(path), exist_ok=True) + cv2.imwrite(os.path.join(path, f'view_{i}.png'), x[:,:,::-1]) + +def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + image_list = [] + for i, x in enumerate(videos): + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + # image_list.append(x[:,:,::-1]) + image_list.append(x) + return image_list + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in 
tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents diff --git a/configs/canonicalization-infer.yaml b/configs/canonicalization-infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f62906ef7cdb0c06cb542bea8aa473ee35819b71 --- /dev/null +++ b/configs/canonicalization-infer.yaml @@ -0,0 +1,22 @@ +pretrained_model_path: "./ckpt/StdGEN-canonicalize-1024" + +validation: + guidance_scale: 5.0 + timestep: 40 + width_input: 640 + height_input: 1024 + use_inv_latent: False + +use_noise: False +unet_condition_type: image + +unet_from_pretrained_kwargs: + camera_embedding_type: 'e_de_da_sincos' + projection_class_embeddings_input_dim: 10 # modify + joint_attention: false # modify + num_views: 1 + sample_size: 96 + zero_init_conv_in: false + zero_init_camera_projection: false + in_channels: 4 + use_safetensors: true \ No newline at end of file diff --git a/configs/mesh-slrm-infer.yaml b/configs/mesh-slrm-infer.yaml new file mode 100755 index 0000000000000000000000000000000000000000..08b86c8639eb1457d5ed552c2271fedc6ed43ff3 --- /dev/null +++ b/configs/mesh-slrm-infer.yaml @@ -0,0 +1,25 @@ +model_config: + target: slrm.models.lrm_mesh.MeshSLRM + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 128 + grid_res_xy: 100 + grid_res_z: 150 + grid_scale_xy: 1.4 + grid_scale_z: 2.1 + is_ortho: false + lora_rank: 128 + + +infer_config: + model_path: ckpt/StdGEN-mesh-slrm.pth + texture_resolution: 1024 + render_resolution: 512 \ No newline at end of file diff --git a/data/test_list.json b/data/test_list.json new file mode 100644 index 0000000000000000000000000000000000000000..ac63066798e5294542ee481d6e942de6e8cfe156 --- /dev/null +++ b/data/test_list.json @@ -0,0 +1,111 @@ +[ + "7/3439809555813808357", + "2/6732152415572359482", + "6/6198244732386977066", + "7/7008911571585236777", + "8/8155498832525298838", + "1/2204149645140259881", + "0/1323933330222715340", + "7/1098644675621653787", + "9/6777209416978605329", + "1/1542224037528704351", + "0/8703316823295014690", + "3/5204013134706272913", + "0/6457137167414843850", + "2/6617574843151473382", + "8/7981152186026608038", + "1/4344590844740564561", + "2/7649110201056191442", + "2/1146977392849123402", + "2/2426517581512337892", + "7/2824689386300465357", + "6/2270010410433478366", + "3/3814323604952041013", + "9/8728960448674306769", + "7/1506365063811110387", + "5/5718924742282692475", + "1/1633099290949034671", + "5/8999640709832005845", + "5/720254657332917065", + "7/4357384925726277837", + "3/4227726538279421493", + "2/4382303856103217892", + "8/6632593566609006548", + "7/3749944138508065767", + "2/878764636138223992", + "5/8170908340955840135", + "6/4845695357833755236", + "1/2743140748471131991", + "1/5803218296084123071", + "6/9182882771353803536", + "5/5872666540206860925", + "4/9212223181352426964", + "5/3899312551169605935", + 
"0/7695929267562496220", + "7/3104109662674926717", + "8/2319063723115019838", + "6/8112121852475729956", + "9/5705939742315993109", + "1/6952166826280123421", + "0/6830091751476954110", + "2/8891263394100940152", + "3/8287958311266406833", + "9/8934151403263879299", + "7/730625960893750417", + "8/2007959965099676308", + "7/7110997111250638537", + "1/1910258394089325361", + "6/7538221091944098366", + "9/8509393563940760269", + "3/1981376850787241243", + "4/821179359686508964", + "6/2359248447840976906", + "2/5396219174677320232", + "7/4683457172478674257", + "8/1863701953709398218", + "9/910003033484940229", + "3/880320695540753593", + "0/990769530404275120", + "2/4551500513185396552", + "5/5015097855418058995", + "7/4896074338113329997", + "5/7306978321405535555", + "9/7776834385265136719", + "6/6631395994048613416", + "8/3757051138516476638", + "3/3283421712821668743", + "1/8144010044536474571", + "2/7876180780086370752", + "6/1647234603582341626", + "6/1341337037707864016", + "2/6302505551505574612", + "0/3465024955374919620", + "5/7900060151297927765", + "1/4675194210589373061", + "0/3282208207844657250", + "4/3240020585468727994", + "2/7833064532316643952", + "6/4790345485250053216", + "7/2935339105576984837", + "8/2599602859354916028", + "2/4769742243183930282", + "6/604217236327738596", + "4/5117485835686648194", + "0/1487097526635566140", + "4/3484530361677579674", + "3/8530544536064633943", + "7/4144922250519743927", + "9/2413192196654279969", + "2/1350971297625987822", + "5/6433334135280042785", + "7/6692827166906062907", + "8/4678213844371676838", + "9/262140445129918559", + "5/4188635875053572005", + "9/6950138434143075689", + "4/6953579337597168824", + "6/16762222989681526", + "0/8704380013906593380", + "0/6734578480501157450", + "1/8562961060475858791" +] \ No newline at end of file diff --git a/data/train_list.json b/data/train_list.json new file mode 100644 index 0000000000000000000000000000000000000000..69578f9738ee76847acd671952480353d6ac7c94 --- /dev/null +++ b/data/train_list.json @@ -0,0 +1,10704 @@ +[ + "9/7067719180331438579", + "9/873743095256255129", + "7/825918502829625037", + "3/355758713874488343", + "6/1430240832593469446", + "6/5348895794631984346", + "6/6363588157801306286", + "1/4133521169028589781", + "7/7062771409315115817", + "1/5294134435767282651", + "9/2411832157706392999", + "8/6439116528565341658", + "6/7333563959008422226", + "2/1530072300304794132", + "2/1393567942866003682", + "8/914623853898718908", + "2/7423364123038651862", + "8/7375042070748476938", + "9/2000518679581906489", + "1/2230583144806488771", + "4/8469815868270111854", + "3/254121803125341403", + "4/4730370512010041454", + "7/1324277971638096817", + "4/4031523998783585204", + "7/7329554934263169727", + "5/5391716625302759955", + "8/9129304411542992128", + "2/2441326254016319592", + "6/7794596601761437726", + "9/2219514691229167139", + "2/3867111405362911542", + "7/1813762583047261167", + "8/6646340902650705398", + "2/6393406335313104462", + "8/4452076253095336728", + "9/5770005216945444739", + "5/599586818355752295", + "1/4998261633157467151", + "4/4876344295380323034", + "4/1780107656666421854", + "6/4227327970096798056", + "1/7408854719268264911", + "1/2688949398072248391", + "1/4717045365677648271", + "8/3897756981079538848", + "1/4672482666430099191", + "4/3138045737141200124", + "1/8003709930821431811", + "8/7109703511132598", + "5/3316713785534135415", + "2/7236028209190114132", + "7/957786991190086737", + "3/8845080473614443393", + "5/4159929557075258165", + 
"6/8024220135013269896", + "5/3734418945126999675", + "0/7042656360297205580", + "1/8034382074004933801", + "9/3245768020775560939", + "9/2784954933697082559", + "9/919370715781188819", + "1/8910077933676264741", + "1/3665151837915064641", + "4/5437016270741417424", + "0/815760880245243370", + "2/8389081869026794822", + "5/527448473572660695", + "6/1267436427722557796", + "9/1392602010798970909", + "6/4172625736474347976", + "3/7885501204059835453", + "2/8963920132947456382", + "5/8893788138323287955", + "3/3472996924875914173", + "8/4041105008298983808", + "0/396973528195932890", + "2/4980685568250817012", + "3/7517601002174282343", + "7/3607306510433700757", + "2/5753123808320641332", + "3/8901310971744228793", + "9/1100718674890536469", + "9/8165252848931292889", + "4/5977406131606415184", + "3/8150768453308892493", + "4/5736484362380956424", + "4/1357896220782738384", + "0/3461017189180364510", + "1/2696752206612441081", + "5/3812763527134696315", + "0/4263959701169154120", + "8/438605442439682998", + "2/6650937494358500132", + "9/3633834117599565879", + "7/7511816175803131977", + "6/5278444475743501096", + "6/1436377050953662526", + "6/4356019403786899426", + "3/4767471923261329223", + "5/8402397133732160145", + "7/5463842245333619667", + "0/8935394681273503420", + "1/2357140593357843071", + "1/8558741977673664791", + "3/1337562090770316483", + "7/7521538037958820407", + "6/727783369030348816", + "8/4420497691628352138", + "0/3764481947787626140", + "2/274008431446817732", + "3/3437331346529574743", + "1/1878061556648112061", + "8/4506221522730674378", + "4/6225919731412999664", + "9/6587285740443054319", + "0/5620402198269073980", + "4/297399322277309154", + "1/5521101906961849191", + "3/6721198737295149193", + "7/6155112029839704977", + "0/2298748470490203300", + "7/6433016014128390337", + "3/741162389022603313", + "3/6237086427260515783", + "6/9017442908026448056", + "9/192406529517133079", + "2/4845725297462474492", + "0/8888368752953207510", + "8/76028455457744368", + "2/3060470529481827002", + "0/3110466464034847360", + "6/4229482759826764566", + "8/5573747370610735438", + "8/5031233638049905828", + "1/3594539615127358401", + "4/2299212724638591134", + "3/8492126058247535403", + "5/3529120802927635505", + "6/3063237634044846306", + "1/5531766162681345111", + "6/8158678668971960116", + "1/1409958574012593071", + "1/4679151594301325191", + "2/7303382683672708072", + "5/3992649821919072025", + "3/642495181627729033", + "8/2974468069774829968", + "2/4880286773680336632", + "8/7650655183255883758", + "7/725664129714175127", + "2/8732821369518478392", + "6/3651133555043553966", + "2/290867584957661592", + "3/5170722977896751143", + "8/889619533062992038", + "3/7988934640415616323", + "3/2027848636289130093", + "1/4156670572325187401", + "4/4810110767153555644", + "0/427396224867369330", + "4/3336891135895394264", + "5/4567144777158405315", + "7/5918178101084869747", + "0/8000799183243168710", + "0/4089379538525763820", + "3/3030599391575397053", + "8/4605108054728635788", + "8/2957919728017576238", + "7/9167640545484556527", + "0/5172858305116714710", + "0/7789520090253306900", + "4/7474182781387807274", + "7/2192356550381367147", + "5/6217212075357934575", + "7/6514969318740062237", + "9/8127365883744036949", + "4/50077920950054374", + "8/316129863396362338", + "1/5809053930203624151", + "3/6117868846290259573", + "4/7451391488998733524", + "5/4660293314540284925", + "8/2957585095093644238", + "3/101759313492762033", + "2/2291992557548649802", + "5/7202360085357564825", + 
"8/2458825067624946188", + "3/3436737565125838263", + "5/8926865230895540115", + "2/7527438108201888412", + "7/4684255223753092697", + "8/3706374324829416168", + "1/4964122100046459771", + "1/2719243977589910101", + "3/3920394029419623883", + "5/3081587203408232675", + "5/5894782232309926135", + "7/3721967170648584257", + "7/5626918627661825397", + "1/8224031502850773951", + "9/9022628047922815759", + "6/3493377998857535286", + "4/2765412972905311064", + "0/158495417346710430", + "1/6791394175182366541", + "4/6672696767364048174", + "2/8352466282351880232", + "9/689765440544731939", + "2/7120739048128891672", + "6/4948066213521280116", + "0/3532352608945117170", + "4/5303687892461981814", + "5/8979101851198088125", + "2/9154915791270579142", + "5/970018378754800245", + "0/8389224356776822560", + "1/8576216623061485151", + "6/3699034363763347666", + "7/409537938516786897", + "3/1985724458739161923", + "8/1809877326732037048", + "6/5691499296497596246", + "8/1140171187759744858", + "8/9096376762247732508", + "4/5465500553526305324", + "8/4281018339721199228", + "4/35741832170373584", + "1/2136800623119878871", + "8/4061973686341253608", + "3/3679583753966600653", + "1/6731867896282084321", + "3/2928897787632996243", + "3/7147626609772188943", + "9/2175904026559326029", + "9/9178504721633041619", + "2/5554820826471298982", + "5/4498009745895107115", + "9/8551859562215743779", + "1/621730251299475051", + "5/7073034385530877005", + "8/182243890626887468", + "6/3208382687804696816", + "3/7438204804185032333", + "1/8082442887688284001", + "6/7023985963271289266", + "2/4531735942934334672", + "9/8616693546066686269", + "7/3881308784029097797", + "2/5545655870938777262", + "6/2761492786021351496", + "6/5129654180087712736", + "3/3796195171591802013", + "7/8159364047384559137", + "0/1756245348131755400", + "0/7593464013742752040", + "3/1337121005608304633", + "9/5861354260529182869", + "7/529458178589970147", + "3/6225670290099542903", + "8/2788695870225939368", + "2/3606434015620076032", + "5/8237159930966857515", + "4/3365861068880734424", + "9/7754950319069611829", + "3/8269651625415476073", + "1/3503434101429144961", + "9/5637587012273736089", + "0/3058600652178443110", + "2/5338502896772309382", + "2/7611508635352485612", + "5/224905387865820015", + "7/99135446856012657", + "9/1515387038421478619", + "4/2112655397200058924", + "6/2942890196081285976", + "4/6070772456826090184", + "2/1381138388015572182", + "8/140061987942389418", + "8/4101582223093622348", + "8/6540529743471888288", + "9/2561565596329420769", + "3/4784885962249063213", + "4/1373493465952075244", + "2/4518729813653393582", + "3/8078143271408337163", + "3/9139777306075016573", + "6/1206545724687169806", + "5/2524020563594372145", + "8/637328495549646608", + "3/2213162180844639443", + "6/6787665184383868896", + "1/5984088369421540531", + "7/1921597549181717867", + "3/3542903054937827083", + "4/3902041218968964994", + "5/1343083580194350305", + "2/3468836247417891142", + "3/1509885226984854813", + "9/2655255625300067799", + "0/4232314507957073140", + "1/1487609098384792571", + "2/6785745410962077282", + "8/3686158670036437828", + "5/7468754363118423225", + "1/2784030517501305691", + "8/9194839211595613148", + "1/1916463211330544571", + "2/7860062650922853892", + "2/9129590785957257682", + "3/5324985756938222573", + "1/3612663339037528851", + "1/6656338785488204611", + "6/7968633713996319986", + "0/7177828059554291450", + "9/3773559868943126059", + "6/2301929386045394116", + "1/5413252724795497531", + "8/5697878507269623878", + 
"3/4785278436283981733", + "4/194507835336778444", + "4/1995929407271405304", + "8/2771118681792168668", + "0/7938976500746859490", + "5/3517397628069587455", + "1/1176527751832228681", + "9/1714097464871634189", + "8/2869209671257940388", + "8/5917979798289022328", + "3/3097247219040037143", + "2/8261153689482662852", + "0/4125010201198246410", + "2/455093405321602532", + "5/3110438585028306505", + "2/3578684916488743922", + "3/8257851689213055483", + "8/6685808611110226138", + "2/5359320933588529322", + "6/1122990510369203406", + "5/4502702882548871335", + "4/8831639097490830004", + "0/2772988205041791480", + "4/8248688506529059464", + "9/6956303378739320459", + "3/217682653012647533", + "4/2107066096198451634", + "9/971347010782588309", + "3/4523379039112367383", + "2/5637870368674562282", + "9/2067828027432579329", + "7/4328230255846348527", + "5/6857702705836239855", + "6/1789374794600248896", + "0/8367847544880899280", + "5/3583556082687070075", + "6/114919545267320626", + "8/1055179520364421598", + "0/2807605844198789530", + "6/5776441899930088826", + "2/7256892809225831632", + "1/6771037327264597731", + "0/2050558999678266090", + "2/4999733780350751742", + "4/8990791639611776404", + "9/8492132066180929509", + "5/2054435755200585555", + "7/6112922286473846657", + "0/6390463699285524330", + "2/8355707445555826132", + "0/6348552246717290040", + "9/9193477097029657179", + "0/8433370769545711120", + "3/1572933028859430703", + "3/6382454017732456063", + "7/7412329934701184207", + "7/9135220721814830497", + "0/3680544747505011980", + "0/7483180197742723680", + "4/1396696834246536844", + "8/5997942691764525088", + "3/3677549253602535113", + "8/3630000907668719768", + "8/7098424984987534238", + "3/8810286222995228723", + "8/6241866922434384818", + "0/2202490205573898200", + "3/217602163163488753", + "2/8449666785373054962", + "0/7678001411529189570", + "1/3628219351857081811", + "8/9200831262969696098", + "7/2311033727026071177", + "8/5491976365207423228", + "6/2365656634184303736", + "2/1506023877218826342", + "1/20822327561190851", + "2/9193248666869616492", + "8/7868186766488771198", + "1/1935072728590389891", + "2/4133281825102763502", + "1/4546037838387991631", + "6/2257092931053603636", + "3/3088158537340799423", + "1/3101303186252934231", + "0/3805344958290963880", + "2/1393035830671916012", + "6/9099801193072789576", + "3/2832080024523265553", + "3/538947081538161813", + "6/3351931332832305626", + "7/209133608536209577", + "1/6690541765826495011", + "7/345328220680777217", + "8/8377937553915552668", + "9/4147941192784392199", + "7/8670203922453239817", + "6/6692570637854314546", + "1/5008579611015443751", + "6/2352158623568928006", + "7/8362878631988874757", + "7/7281472612656368387", + "8/5876232804607189238", + "3/3824039659973193633", + "7/1257621784885787387", + "5/5177716416192716205", + "8/6472978587754762808", + "9/8518903751656604129", + "8/1945416999507519088", + "6/2631889239438017206", + "2/3051383117439063652", + "5/6544715578336440085", + "4/2614965118725858464", + "9/2357471818959334209", + "7/7278581668679727487", + "2/4407954046739923332", + "9/3303261255858677269", + "9/2287861475341351339", + "4/1962965152838305174", + "4/980699973061049924", + "6/6098544033609974466", + "5/4526718634015404455", + "9/338145597471314909", + "2/4248693116586578632", + "1/2915715444285138961", + "4/7860114013084220834", + "4/1561794585243959344", + "0/8262409658999138080", + "6/2056599967932504536", + "5/5345640411989821555", + "7/7122584232289718257", + "6/8615446718291193396", + 
"1/6835313259338533721", + "1/8375332592266344431", + "9/8145575584639437209", + "2/9058255447431993242", + "3/2187299772203613563", + "4/3763250162048057284", + "8/5641872447706021918", + "9/887908503868039579", + "5/9216127804556933575", + "5/1006799417544436825", + "4/477741859036429564", + "4/315703799847088574", + "6/6798320829604694906", + "6/3111910516783930686", + "6/2795582266426113636", + "2/372814291372602902", + "4/6794802128711664904", + "8/7950436561678162478", + "1/7931442559571353081", + "7/2672868711170877577", + "1/4829034243951135901", + "2/2898169452733875652", + "4/1550199843465167324", + "1/497800745961592741", + "9/3339018748782142589", + "0/832910598912862560", + "6/2775153294100565646", + "2/6471228000569636242", + "5/3788330656279862185", + "4/5209679646385665604", + "6/4513252687101535866", + "8/9010706568930739688", + "9/5157506821084303659", + "7/4957719108449168867", + "8/257353358459112778", + "9/5151379126907689829", + "4/4068466383107212244", + "7/1360625021277357397", + "2/2656801322705739352", + "1/1984110687261908061", + "2/8223311327295710812", + "2/5650817543803294882", + "6/7859398311643883886", + "4/307152399930335374", + "5/4241645710074544045", + "0/5858419181677172540", + "7/7507897818276356827", + "7/6398403871304022567", + "4/5107611009420246844", + "1/4205303764678706801", + "2/6037079455243663552", + "4/5141514256099581824", + "7/9033714307124454657", + "3/6014454049723139083", + "9/6609872357509818399", + "3/8263331353357540253", + "6/5278638334906713646", + "5/7363144399662688365", + "6/476394033340014696", + "8/8866545235857421588", + "3/2797375420881902793", + "6/2870830472618867246", + "0/4226983837021293760", + "3/5249300827034551223", + "8/7633196846578193288", + "3/5206484181994218153", + "5/1528567607744376425", + "7/7499382320577456477", + "0/1726526286747502760", + "7/1496515729782978877", + "1/1087029488860255151", + "3/4390370895584773043", + "9/9008963505354611889", + "2/5169376288869386002", + "7/7726970318262732057", + "1/4756724561242989141", + "0/1727775897434774650", + "6/1057727587374292386", + "7/2378213207043999317", + "1/7566616606526846951", + "5/293234545391521465", + "4/931880135123550914", + "0/8027831038494999240", + "9/1715158652427058909", + "3/3079870631047232533", + "6/6066293377887775436", + "9/3030754064190019129", + "7/3211474043337062707", + "1/4520754819153774991", + "4/1271805975978668114", + "6/3696026799347124846", + "8/1055664085489968178", + "4/2702583817442649884", + "6/5358469191922076256", + "8/342915425161979478", + "7/5014702847040237287", + "7/3383751442912063017", + "5/617662621671999285", + "1/1109132397976870611", + "3/1845521706304530923", + "7/6355195077645968207", + "3/3129898194394824093", + "1/8365002118131684481", + "9/4139640884928297519", + "7/420545608616777747", + "1/4286246354162142501", + "4/5701967324875172904", + "0/4213196292582719310", + "5/5475868033038452095", + "1/2497601091155757211", + "6/2604303000831561086", + "8/8128335963387728208", + "7/5425714201660345387", + "2/2577460917027827172", + "4/6746934255553156634", + "3/4414984268523586293", + "7/1382842979403050097", + "0/5721923735786686500", + "0/6013861093241849140", + "2/1634033091184592612", + "0/8635482204343521060", + "5/5639616634756226065", + "8/8857824844632446688", + "8/1309080047874941928", + "9/3921650526017068939", + "0/7605981428195413160", + "7/1554139888650241587", + "8/4413654419671973178", + "3/7627172707176428783", + "5/2559901703448533185", + "5/8131048420096630615", + "6/3967559449241678696", + 
"9/3091987662288626499", + "5/7966409527662614815", + "6/3225859174039254796", + "9/228726765229712049", + "3/7565269830152447563", + "5/8975107540634217495", + "8/3562507074179910498", + "6/1710986234242252256", + "6/535644384169592386", + "6/8441233130969242046", + "0/8074856770639515020", + "1/3127506230758392801", + "2/2658544303158784602", + "4/6332312833516801854", + "4/8956976806670161554", + "7/6184442959372966247", + "5/213764751511314985", + "1/9047009865154235611", + "6/132247343285110666", + "2/3859638009390151932", + "6/87742793581800166", + "5/7146486942739561315", + "5/2251093004334611505", + "7/4069527967281881297", + "9/3261062364087577569", + "9/7610018928284419599", + "9/8903656163433144589", + "2/1426429297451026542", + "3/7551061461117680333", + "7/2529593770095536447", + "7/654116924621177347", + "2/5342606075560991762", + "6/9054308763879343536", + "7/7701430418926637387", + "5/2556727474255527355", + "7/7037113803195503557", + "9/3366437818799848049", + "1/7623985574392900121", + "9/4485444783391813619", + "7/8610702594407553377", + "4/8017613268736820994", + "7/8865079099627632977", + "2/172905453922646782", + "3/3986309793164881793", + "3/7760018184999878843", + "5/313212166305968945", + "5/900773634328399835", + "1/2075581780753499891", + "2/3509155013473841442", + "3/4380549049974448183", + "0/3324097067782386040", + "8/6662209333202543618", + "6/8647538772428058486", + "1/4535824482964496111", + "3/3306209887297827263", + "3/2387660069509233433", + "2/7952184021385427232", + "4/6460358452295137724", + "4/836468045456903654", + "9/1091672971553348119", + "3/4066133015423582903", + "5/6676755074261288675", + "8/6442072963206770978", + "1/1357558172976932331", + "4/145317236772291854", + "0/1118161408009810020", + "8/5172975507488959028", + "9/5645345810029395079", + "8/312612646710716548", + "0/6661008364527195590", + "4/1593181122290030814", + "5/8494124561606161775", + "7/6617846191419192787", + "5/1221288937244077575", + "0/115573660714302690", + "4/2532489563255244034", + "5/1760202437370479045", + "9/1934090972924831379", + "9/2741397729458803659", + "2/8009708092039059222", + "4/5548724676419750234", + "9/7333366723041942269", + "9/8633169062231675729", + "3/1381185908427242333", + "0/7499482555770800250", + "0/6018826591592337230", + "3/2992661420849403543", + "9/7265492272999040449", + "8/2615637260793096228", + "4/4492648597588349674", + "8/4907494749134882538", + "5/1312119359604068365", + "4/5402901593733060504", + "0/5977158210136328630", + "2/4846585125076429152", + "4/6496131285111957324", + "7/4572977518803093287", + "8/8974337805841615138", + "1/3218395715266991341", + "1/5131548166288185421", + "6/1470796716115975746", + "0/4328188147093725850", + "2/8479244661782879712", + "1/8422935395567736691", + "0/5491317038403698370", + "5/8019349040524931445", + "0/5489213839380529610", + "8/4909018217288976468", + "9/6458544352115054239", + "2/7124712349118669092", + "6/5848344336231169116", + "7/254199666039116287", + "6/8186981109641174706", + "4/9012292347856087844", + "0/5062520653326223120", + "6/2334112060819294286", + "7/2589133693002318887", + "6/1583577535824123646", + "9/4821683784603591809", + "0/5213375051399188660", + "2/1431623522240256412", + "3/2907349599171434383", + "8/1870356226328135098", + "8/1914643652103556808", + "5/6658447238786234945", + "9/8970901442329492689", + "2/439542171618443362", + "6/422516858365252446", + "9/8546090583288720969", + "1/6166989665417885301", + "6/5107860010518530296", + "2/5028910112597032412", + 
"0/265693881831429600", + "8/7823032462254457808", + "4/3127878716793801594", + "7/1751817668843199697", + "2/2220977868644819842", + "8/6218892100068777008", + "9/5075973546910152779", + "8/683178160625021548", + "1/3464814985501340101", + "1/2658496460439839761", + "6/1902023197563859616", + "5/7503964918662995625", + "4/7911544648096987464", + "9/8227176227548908619", + "1/8590238486593976221", + "2/5579811459889406392", + "9/7784871601059176869", + "2/7676798979245675072", + "3/7209714471425190503", + "6/5753963511824055036", + "5/6997022447100807235", + "1/3465160762687355081", + "2/5680462154393945962", + "9/3580438195834899649", + "0/4214338010227341470", + "1/9020784299900229931", + "0/3867145843520076910", + "0/6026774473043542920", + "3/781151746456361753", + "5/8299895595917064605", + "5/5026047563776862715", + "7/4034168875723545477", + "2/572152403748192702", + "5/2217644973232202335", + "5/5021179383788655075", + "6/5246810828745951816", + "0/7792382672409042710", + "4/7581783230036554354", + "7/1341307979660293047", + "2/5490096827765287172", + "8/8691237556584156028", + "3/6519025107971781043", + "6/5010279197059293066", + "0/4144683421539045650", + "7/3344375804897637197", + "0/5845143173536083280", + "3/4402818084527065813", + "5/7749809699301067205", + "8/8272840543034416208", + "9/1268341833256029209", + "2/462745939168556072", + "2/6965689083998764272", + "9/5779503308116064929", + "9/5071157361770276439", + "0/6572170742754325140", + "3/3767746663193436593", + "0/2697323717157458900", + "7/3566876752844485807", + "5/8398407325939895485", + "5/5949079029581838085", + "7/8865762915706186787", + "3/8235938808750377033", + "1/3412750920836365461", + "4/1734059925327134384", + "3/6243093961287102583", + "5/8766372264643424095", + "2/6718698961891920252", + "1/1694721159670568101", + "5/5134431763197081015", + "1/1657123377436338671", + "6/4784654365891608026", + "2/3989684540884810712", + "3/6262098144992983123", + "0/2377045089208906860", + "3/4307435689630402783", + "9/7388358548245757009", + "2/3790243492610960372", + "9/7769571434795134879", + "1/9203467159313512831", + "4/4332299479194544324", + "0/6372901907607126600", + "2/4991623945989355392", + "3/3397950325695203673", + "1/2483992054347285101", + "9/997911887193589349", + "1/7216406401821063001", + "5/4911755947499683075", + "8/1257686576865589478", + "9/4975060221384999349", + "2/5542029714720682222", + "2/6731389463174195682", + "1/2127549029109151671", + "8/9062934487411623988", + "5/1074173863623961375", + "5/1630291826970164055", + "0/7246548013864359760", + "6/3982234919292026916", + "7/3092019083358331267", + "8/4407470502515266338", + "7/2271009628842692407", + "1/5426883811680894091", + "4/5906721519408979774", + "7/7912632717515473917", + "4/6064832325163610054", + "7/8011499053565131157", + "2/6100564161352433372", + "4/7352526686505303224", + "1/1401369168532460571", + "0/1212088555229009410", + "3/205399557937193293", + "4/2189497852238263304", + "6/8053167828895725496", + "3/3931452965897195073", + "9/6690322254844992069", + "6/3114013802125185076", + "5/5157132466188255845", + "1/8757726435660766791", + "8/2252077541816984128", + "3/2487864093800546473", + "2/2981913236014364082", + "0/1694004807779533410", + "3/357789660565409693", + "1/5370458882486076421", + "3/2784811169495498703", + "5/3171279309455981855", + "5/4536624621665720265", + "1/2904470182291003871", + "6/7414678896699535166", + "8/2377077670363360768", + "6/2527107308383594196", + "7/689280736643250547", + "1/2624607373990005371", + 
"3/3050752005740328003", + "8/8338743958694264618", + "3/6990115480480718923", + "8/4764399415918061348", + "0/3178045218507561080", + "0/4972353594554679670", + "3/7693548548476493353", + "0/7389974636498430490", + "4/2860734323858529414", + "4/554753089558606494", + "4/5759450759111269994", + "8/6785682041523404738", + "0/3806355874761573410", + "0/7158578614313000000", + "8/749131699767688908", + "9/2827153591664283859", + "5/177589812913393795", + "8/4373906176985557448", + "1/1486335323732609891", + "5/8911186847787229025", + "3/834535115800583053", + "3/1358675409050355913", + "8/7877854018110255028", + "0/8221040570619468360", + "3/5007385710038818423", + "6/3765963516852873916", + "1/1162567486754192651", + "1/312596298366943791", + "3/8053763459694545433", + "9/8196527089455772029", + "1/6863856027683363331", + "4/3991746878312368604", + "0/5615036675995875910", + "5/8512271703730839565", + "2/6833385602186659462", + "5/2918795826994770785", + "5/4599696502972910555", + "0/8312114910777916290", + "5/5411634974229363635", + "9/4122879212424600449", + "2/3772956412224632622", + "0/5195726811299704870", + "0/6757190673025037990", + "4/6657498104824700544", + "2/2735564432706549502", + "2/6632844840708282822", + "8/8471344480232775918", + "8/7006139036482061008", + "3/7152293501603062783", + "4/2968655506119680164", + "7/5264670331393928587", + "4/4624889486034598284", + "4/2979594526490784874", + "3/5638624753904224953", + "4/3467084589621137254", + "8/9155123254092687868", + "1/5560338084518139051", + "4/4423006740960699034", + "5/2559391626303297935", + "6/4960184383506621536", + "5/6507951763448284425", + "6/3766353507642040246", + "9/3816076484247208299", + "3/5000430812650603453", + "1/3319406557507521511", + "3/2296044519663984363", + "8/4291797465670253048", + "3/6598916172483925393", + "6/3697394808599374636", + "6/5327550832512007836", + "9/5510333403968493609", + "9/3755918142956121209", + "5/7636888025049801295", + "7/199855900740785337", + "6/7284474866464103606", + "0/8097056750985053470", + "2/4036652283010283462", + "5/7606676930613455215", + "4/1750342272405303024", + "2/5161260273641014352", + "5/378474000147856745", + "4/5906952404122853274", + "2/3446496590917567302", + "5/7214488984524110685", + "3/7960013553573330973", + "3/1672192252304467223", + "9/6197499465136782769", + "6/3078659758600513106", + "5/3408127186773431845", + "6/2754637872959534406", + "6/3426416987514552176", + "4/7927698758430086614", + "4/7077639655551445194", + "3/8630486461848266683", + "6/2202648087060415116", + "8/4958500762392133018", + "7/2902859845798083487", + "7/3612774442973836597", + "0/9020974662553621940", + "1/7448500860388557151", + "1/4721008145345871951", + "2/8873239130560974462", + "6/3439971856177076856", + "1/4371611380109090911", + "9/8266935361247384019", + "1/8077142887631869301", + "8/2311830250306814708", + "3/6908070641437939853", + "4/3599584599180516124", + "1/7700777981517783921", + "4/4334653905304799964", + "5/312635119192760375", + "7/1267920000007019107", + "1/1784439579952164251", + "3/2534860216983941993", + "5/7944238415963342645", + "9/951732216359509269", + "9/4943217792670767469", + "7/3631603070348930857", + "5/1551918882599055445", + "3/3662800110754797483", + "6/7652500106222717006", + "7/1134531217565158567", + "2/4450680424595183712", + "6/3822234370301563756", + "7/4874468746831214557", + "3/4961327097627052283", + "8/3470574830177641588", + "2/4068702375678266052", + "2/4118099537374659252", + "9/7570008436281139069", + "7/3093814438233414617", + 
"2/2867692098142819062", + "5/2102201183861081145", + "9/247727672565219489", + "4/6490178928922019904", + "7/5052586878171430347", + "9/42541341466447439", + "8/2379581107067821228", + "5/5026946661941919765", + "2/4001162908777714592", + "8/3851812134194645848", + "6/1214597748435876356", + "9/6136396027987524059", + "1/3246774755930247691", + "3/8211109093333591803", + "9/6158572857280157479", + "5/5401708425744029885", + "8/2939262480766465618", + "3/2888910692933454873", + "1/477859093394825511", + "8/8804175170445641238", + "8/7973931665432955748", + "4/5128261934418143194", + "3/6360809473515527413", + "8/5311609714161599698", + "4/5110817549808611164", + "2/2651565557402789282", + "2/8510449266170987222", + "8/2099585839891425608", + "7/9036670647358942987", + "3/843639379066465493", + "9/3394482358697610939", + "4/3407432699820922624", + "4/8632416769107469694", + "3/1857188997338734643", + "0/8369647047646148080", + "1/2564667852517275381", + "4/3468871805519333984", + "6/1178295306591329616", + "7/1382122020907703197", + "2/4819588958666962552", + "6/740238687111500226", + "3/2754075486794210213", + "6/5831611710690267676", + "4/9176626278410287544", + "6/5887643978173828986", + "0/1235541902522224030", + "2/3720845440531286392", + "2/6352715204548992022", + "8/5457826387835947018", + "4/3923910991539845874", + "2/3383822645205997082", + "3/5211207779264688233", + "3/3090125421767118793", + "8/6271630054382494538", + "1/5924974130952549211", + "2/8733362025606154832", + "7/5942415088698701697", + "5/4032992653520980025", + "5/5440797798808079685", + "6/1256261155069994096", + "5/164988712820456275", + "2/6102449611523792972", + "1/7149685881613780321", + "1/2212314670949046311", + "1/3733240847440610651", + "9/6062817587281704069", + "1/6812512852883270611", + "3/4133871064416339223", + "9/2410879113857711269", + "1/2405596258387198341", + "0/2748342085424666770", + "1/6343761810885522311", + "3/8874113692797919593", + "6/6995671106114998266", + "7/8944149220003508807", + "5/7370821296276384695", + "2/5295946555641360982", + "2/8195964786039318272", + "7/5613085298145417417", + "7/8878394038280449887", + "4/7869907298647232224", + "6/368785483073033246", + "9/8659411804431872219", + "4/5726181147854757094", + "2/2775431071008255092", + "8/9164786020475693318", + "8/7198784705275123398", + "0/4688409478273037820", + "7/1090003227616075297", + "8/6360747645334429298", + "4/9150896676959341704", + "5/2977154069217817295", + "2/1737871928330550922", + "3/1605231876067779823", + "7/7623940773528496747", + "3/1023295906498158723", + "3/3291151234541016313", + "0/1575226264825187070", + "8/3815051897785964878", + "9/7700466902413520089", + "6/80125596272577186", + "3/6875207609515472793", + "5/496802426889633235", + "9/6407632832136960209", + "7/1187116067764854077", + "5/2890099955623947255", + "6/157116176389287296", + "9/6959258181023932229", + "7/3090179438311466787", + "8/2113503304857594208", + "0/474222272472287730", + "9/6242633697025904319", + "5/4288331921245818875", + "9/6877468915113437039", + "9/5915442604958635769", + "0/1594196832958963090", + "1/8485359889302429231", + "6/6263841791978224166", + "0/7907815414169406550", + "7/6026682151909046957", + "3/3683918333956513243", + "1/5814895889167989321", + "2/7072781525120376482", + "4/7058964985593482854", + "2/6246540026080971552", + "7/3690734718789692287", + "5/242319359223537385", + "8/7255281500309965768", + "1/8365812354325831081", + "8/4290756309810844328", + "2/3806396256544251522", + "8/953526717586682068", + 
"7/2991745418210736387", + "2/921179147922482382", + "6/1332802659194158726", + "1/2074787734464940261", + "6/7712060435364551676", + "4/715423948920415364", + "9/5263319278874707189", + "1/2105078234238458121", + "3/6833704907884976123", + "7/5935563356535619527", + "6/5572958012792537186", + "6/2140219108748972166", + "9/2797980149834286599", + "9/3532399680612192929", + "9/2722770622625454669", + "5/4202072734205168755", + "9/4209521819919056949", + "2/1067701920345622582", + "1/3899415892347272301", + "7/4840119153257906327", + "6/5672747259947695356", + "8/8100049906800405208", + "2/2023316076252283902", + "3/1777102296767939153", + "5/1451659979600588005", + "3/5191963093015856873", + "1/903493244330384021", + "4/6048841661247486364", + "8/4230311716531167278", + "0/2915410296383479550", + "4/5283406828233254", + "8/3929361344976729758", + "8/5786233353505776048", + "8/3856374400961904518", + "3/2094349065274789843", + "3/4495812979916686173", + "4/8792710390904370104", + "9/6747312555396107009", + "1/8946729879492164121", + "3/3770514454944261723", + "3/7250566638687404033", + "4/8085672168828626304", + "8/8629614949682894228", + "8/2042485001818245738", + "2/8219850164914694922", + "0/3642708940210726970", + "5/3384032076423038835", + "4/5825197095860902264", + "9/6792576961050207029", + "2/3942998026928469972", + "9/7401342401004132179", + "3/3257784371480269763", + "2/2780624868564419562", + "6/4484608095461131186", + "1/370420358975297061", + "6/2721726562906295256", + "3/5087563825496399123", + "4/3132686846894285534", + "8/8898107482817022718", + "3/3104143169446680033", + "9/3011405510493338859", + "7/8414273321161489647", + "3/7330427121554479133", + "8/4674088517276112028", + "9/8162147244586635389", + "6/8422507682531830986", + "1/1049344033829733311", + "3/2059041062764118893", + "7/5351859666144405087", + "3/595849210309179693", + "6/4719208405381388726", + "0/4101867842616169600", + "2/5556663993885133442", + "0/6170456383570238760", + "2/255185943514066252", + "9/4433591926936836399", + "2/2456831854322629962", + "1/5829434386082150391", + "1/7051357572540646881", + "8/258271936293815128", + "6/8706827524684056536", + "1/7185405312957169721", + "7/5391388898778832417", + "8/1511442185528827808", + "3/5045946222444720203", + "6/7130295455222040796", + "5/7613529040682918185", + "7/7732393771392533837", + "1/977542804477205221", + "6/3860432644504912716", + "5/217489632863290075", + "0/3356330938466745970", + "3/6707671366066478663", + "6/6192727499462125116", + "9/4667627859600977769", + "4/2122745044761853394", + "5/1762595656410398115", + "7/8346364426929512957", + "2/5708305310643015612", + "6/3994365652095242946", + "5/4928463562439844065", + "7/3605410899890848497", + "9/1055653793438040129", + "4/5159630980859045494", + "4/4443588401048481274", + "1/3115156356047035951", + "3/1906747259456464253", + "5/6872674208282162285", + "8/882169469251214318", + "2/6490165080348621852", + "5/1766905117587265545", + "7/6687916697225544427", + "1/1985970467213131781", + "7/6745045833569203297", + "6/5281282772383867966", + "0/5722103564331403470", + "2/4704627153761786482", + "0/6714768754304091150", + "5/6224476334209147675", + "2/384659604457105962", + "8/7498716467479393878", + "8/8321821886472416138", + "8/2848263152158284498", + "3/6064521109874100023", + "8/7886461815393274538", + "1/5876063244425452321", + "6/8280670328901552566", + "2/531571040152404652", + "7/6325892710322199957", + "6/2250329674270499476", + "7/8988340285028889257", + "0/6884876040546810560", + 
"8/415638296953994678", + "9/7516642613363873649", + "0/6452983346054439510", + "4/7675240878622194314", + "6/3232304059021096026", + "2/7044707880823556252", + "2/8282539564466525742", + "0/4089360612845109290", + "1/1531187698004953491", + "2/75962930918539142", + "1/9181987379953864051", + "7/5046817007938802667", + "6/7858447821962265456", + "1/6188205044931807191", + "6/5288115967056758426", + "5/4601279865919631875", + "5/3078840504752246205", + "9/1936084445407814989", + "5/69714663244878055", + "1/2711692272732822171", + "5/3067820901946026755", + "6/4402692507463800736", + "6/6663571388929691616", + "7/7589751559138163167", + "8/6405814798260588428", + "3/6561307844714050743", + "6/3402569887482951206", + "6/5510524684577300136", + "9/3305475774775943149", + "4/4439818607708391304", + "6/303618934151875296", + "4/8011461789610162454", + "6/6457955469141180046", + "4/6235820823680110604", + "2/8269636340139048062", + "6/7043178136349026366", + "3/4118767974375244103", + "8/3083882224848837068", + "9/4641567174082027669", + "3/5801950896797720953", + "8/1774450198045898458", + "8/4635390702145092668", + "3/2693644674434863913", + "1/1173929873356452461", + "6/3662617853633260336", + "6/6382874449534114336", + "8/6220946925897614328", + "8/8218483744021547538", + "9/3795964765808926999", + "7/8390915217827917417", + "1/5135340688758272621", + "7/4672676854000528927", + "8/8838824438743701428", + "3/5553937433280223613", + "3/5232644748031178713", + "6/5343038643759787926", + "5/98951685351202635", + "5/3988074351429433655", + "9/1923575756975945449", + "9/3915028651383703799", + "0/599237894156050870", + "0/7682018328841877110", + "2/4454977152945104972", + "3/8388708616460855093", + "4/3688204515205684274", + "5/4995819846025777395", + "0/4348127235762522680", + "6/5505821485896448046", + "2/6366796272648937732", + "0/3011248874864667690", + "4/4213924048194480194", + "7/769321684743052877", + "9/1989705078089752559", + "0/2687158487914293690", + "4/9192119832985380134", + "5/9071023046507463815", + "7/637410677352190507", + "2/8945610953834761282", + "3/2264473144935997213", + "0/1323974719362660260", + "0/8368531017655780200", + "4/999563983084132734", + "0/5403802845192007530", + "6/7189364054834495036", + "8/4914917421606398548", + "1/1515552392393428471", + "3/2476837096630911843", + "8/4517427550853362458", + "0/1107966252827422570", + "8/6199788706619787498", + "8/203683475266366318", + "1/7455106108200845211", + "0/2173874227736583570", + "1/468523830136070341", + "0/5489791290466646300", + "1/3893978059383162031", + "7/5487960612981524797", + "2/6432795723413871602", + "6/8940685119464047656", + "8/5632625304367476258", + "8/7896864909238470418", + "0/3245485593774659670", + "6/8800143430349582166", + "4/291163531130708474", + "9/4043461251643837689", + "1/2603641508996140331", + "5/916046882292102565", + "3/505270318767878003", + "3/3009283137372099343", + "7/8455576477125151547", + "2/3847993790300128552", + "3/6308302299904440113", + "6/6374790876908166936", + "7/1989792230104160337", + "0/4530111849803819560", + "8/4287371345677682188", + "5/8867362672352803025", + "7/8094143335809103707", + "5/8006408968457100375", + "2/3059385373027143472", + "6/2764401345914015096", + "6/806658017111766736", + "1/1799345027946160731", + "2/2600082016450561622", + "2/2901852025516493612", + "9/70726883241967089", + "9/6739125519938198869", + "7/5984889649537038007", + "7/6943166534941386167", + "9/2051842897924180489", + "3/1232400313682664073", + "7/5443160678431595837", + 
"5/2052740533404348905", + "6/8505366405268256246", + "2/48373735174037732", + "3/1600410704528210323", + "3/5836455443884917763", + "4/3873365491404745914", + "7/5558387515599376387", + "1/7162419789786961831", + "4/7176283307078560674", + "3/2087296847428612013", + "5/3144349305792015975", + "0/8646291641404615410", + "8/3599903979792351248", + "7/2003845539823451557", + "3/6246540033384508793", + "0/7842905252392321740", + "3/1907241511581177323", + "9/8385806021094571699", + "3/3163414286357494903", + "5/5043829834690475225", + "3/6954490803761888373", + "2/7094873493536866122", + "1/8887458137349463421", + "8/7531619929224961918", + "8/5718199110449335418", + "6/5156945441288474816", + "3/2111018985456148773", + "7/7922817490067206807", + "2/2363299828127522992", + "3/7852535245899510533", + "5/4629790761962283335", + "6/1141614152376315676", + "7/5844244267286478237", + "4/8800344865959596794", + "0/539003850869479610", + "0/8473591188914141490", + "5/7197000743401337515", + "9/8558776105559992169", + "7/8860758017089565047", + "7/3946167737741576627", + "8/8498073752912462538", + "7/6395710151706636267", + "4/5980295007809284774", + "5/4373822914214547715", + "9/1601638697712573709", + "5/8709823995799399335", + "5/2690056517818280785", + "0/2180868821219142090", + "2/8230520207623151262", + "1/170649228897287261", + "2/1259843191784755922", + "7/1310868947125996167", + "0/419290446040770120", + "0/1686801643915325630", + "8/4351589836603527778", + "0/5176064128372997640", + "2/3011753058647032182", + "1/6649415686288248821", + "3/757503869823449963", + "5/111921628408527245", + "1/4575604150145195671", + "4/1228912461000117124", + "4/7415968921062094644", + "5/2561811196860998365", + "5/8456370581644794155", + "6/5332404150918946636", + "2/1613570712703059992", + "9/6466879417041366269", + "2/1012995748550875102", + "2/2392210178147609172", + "6/9140597884575101446", + "7/7634818719792582237", + "8/6987845027250638828", + "9/151238911374650939", + "5/2698952769769442735", + "3/9163953092985413203", + "5/7827826622044617415", + "1/3240806812376959551", + "2/2348449683127840832", + "6/1091063389393262396", + "6/8924646905625068866", + "7/299968673084547367", + "3/4965522933297130543", + "6/6987482240776493026", + "0/2692751386527662910", + "6/4727245908976194516", + "6/8170444227883610596", + "6/5945957436643781056", + "4/293546145870144624", + "4/3203214879196864884", + "5/8864287433439511305", + "4/4894389097770805644", + "2/7066446972075342242", + "9/8394976304042071939", + "7/3665996814330302957", + "4/4195449474589038224", + "7/3705159738470177027", + "1/1971129114931600731", + "3/7698016333528814273", + "6/8307872442574831406", + "2/8029836146860298242", + "2/4811073942896883392", + "9/9043501620384744509", + "2/4135094750493348702", + "2/9063848489763603212", + "6/1047737722019217686", + "7/1105763048076748007", + "8/6848992570875343618", + "1/8696593597056610481", + "2/8897251037568293982", + "3/384234306730533173", + "3/4163587366289010003", + "3/4915163897291682003", + "8/2521150635056786428", + "1/3973449586430068631", + "5/6374444475661765845", + "0/1613075819964979400", + "1/595110276517849951", + "9/6890437430603817499", + "5/3362807306831534385", + "8/1191455569364500668", + "9/4771685909627393639", + "1/6719067072143159921", + "3/9165394313582993053", + "6/6611958008124828496", + "2/2314192731223545002", + "1/7860469306191157151", + "7/1268345487225644317", + "2/199859097517858062", + "0/1355618821961763180", + "7/8680205580721683527", + "2/6651484967025863202", + 
"1/10984283605502311", + "7/6872552402944758997", + "2/4938617892845474462", + "2/6260199844620222502", + "6/5233970940556583636", + "8/6443113275226438528", + "8/1402113543714102418", + "3/1483162999254610783", + "0/8943996519025208160", + "1/4069016334478796421", + "3/1377085263774540463", + "2/6112503679003207952", + "9/4242104015295642629", + "4/7333049491599241814", + "3/151023333795397463", + "3/1209609499912209663", + "0/5362778520327714040", + "8/2994635260749651418", + "0/807178892880572720", + "1/3144563150621076701", + "7/8332559347021562037", + "7/6174281199510346917", + "9/4981186663954880989", + "0/6697420977431412240", + "2/7043657801443847562", + "6/3266579397748095446", + "1/9074625450251067991", + "1/3478617187121467751", + "7/7959135028997995557", + "4/2206871239571124784", + "4/4635398644847192494", + "1/4049698061753217911", + "4/2237703447330102744", + "9/2942388254591359669", + "1/9213860147217965191", + "2/3142841151237817342", + "2/1302689117498204802", + "0/6775150173141543810", + "9/8447657313067423609", + "7/6423717325331524177", + "5/920386760581338565", + "4/7444832889680652574", + "1/458006334601500661", + "4/3759591186273122284", + "0/1552558937213282450", + "7/5225739374137694737", + "3/4909753373330315493", + "5/9020240928867493445", + "8/5468778098400235188", + "7/5446945317633545107", + "8/6699010200397583348", + "9/2157302826934739489", + "1/1559958871180200621", + "3/5526201643295310343", + "4/4383025477707035504", + "4/6197768943214480104", + "2/4313346611214057212", + "4/3394159562845852154", + "4/7676043641332837104", + "7/940358036000899977", + "0/1287829452191532380", + "4/2954415559356004584", + "1/3945562990936736471", + "1/7195553112756390291", + "0/2980522565449554680", + "1/384611690663081201", + "9/5382514373482834099", + "3/5558202476649691603", + "8/4815664970820778338", + "9/8154183338971795159", + "7/2520424014850472537", + "1/8643683884415679701", + "2/8991713154429241962", + "6/2841807355954606646", + "6/477894117741598666", + "1/5841538337823699031", + "2/4633186250997272232", + "0/7393427492118391970", + "4/8469278338096545524", + "1/995631284001350521", + "2/5792082248553311162", + "7/808466696039480507", + "3/6126041197871904243", + "0/1571385644749914530", + "1/1095264798929760841", + "1/5809519060544254401", + "2/6435924794891109422", + "2/6617763557115899272", + "2/768101844042761502", + "3/4439079187668441793", + "3/925111175431723593", + "9/5223936901037377859", + "6/6167909702414489736", + "3/4173163252907756263", + "0/8634619520603812890", + "2/543089652284491762", + "8/5666653684500306068", + "9/4045592899086151659", + "8/3076857351864715858", + "0/1763225234978552570", + "3/3180297332294680673", + "8/6934159943485898468", + "7/2261308401513898057", + "8/4014942582024706658", + "0/8174510361203308830", + "9/6758153121715480239", + "1/3505220226135068421", + "0/3693690919379523380", + "4/1677180392320868564", + "8/2026838045591346548", + "9/1006948407630337839", + "5/7863948696578198665", + "3/6224576568901140763", + "0/4150908258133118290", + "8/8628516963197630508", + "3/3927140802254440973", + "8/559825228109963078", + "6/8796913968936060196", + "6/2375572921041392286", + "9/1834318473252315719", + "6/1685048464317247196", + "0/8869205701844681650", + "8/3572575519072653388", + "3/3836301997357138553", + "4/1468817197174101964", + "7/8208780947195394397", + "1/290693904956088181", + "4/6950411152453646874", + "5/1112033150398107825", + "7/752471597870401477", + "8/8312619684096902398", + "0/8507605388537280210", + 
"7/2547664192610087547", + "7/3174355482581202467", + "9/2766156668071307659", + "6/5272656872388839726", + "0/5878479567254090270", + "5/7440721942820473465", + "7/3233112343902279177", + "2/2351653257001353042", + "7/7521451100059948707", + "4/1135262692643216814", + "4/2285357290049346834", + "8/6218495914112000498", + "1/572168083690933661", + "7/8988208521390700507", + "2/5998713863691153332", + "4/903818101232500644", + "6/4996040738801171786", + "7/1844383384899486537", + "6/3028085606242606936", + "5/1791936920739581935", + "6/1875216831932238226", + "0/9002320344716770240", + "4/8983251661182098324", + "6/3817929176432636046", + "1/1073382029573933591", + "9/4981518110952104519", + "1/6195802145247918331", + "3/904294405622350033", + "3/7706420700948085243", + "4/6932137419019036224", + "4/7991100460080411164", + "0/8802858686588882780", + "5/1156491825003973525", + "0/3281378081925199250", + "6/2788883404393511556", + "0/1436320070189150500", + "7/4347208334091004997", + "8/648346055155859308", + "4/298157893808836854", + "4/6484941873356349064", + "5/1297846531609076325", + "6/7633446389881145946", + "7/153298829330230487", + "8/4875679767138295518", + "7/391487959521697967", + "5/7371304787676777235", + "8/1017258167469455738", + "3/8017880837792526993", + "9/8280708773361002109", + "1/1535970769978605021", + "7/9131942521470166877", + "3/6966706117425650133", + "4/1595071687360030604", + "6/2164383448977187446", + "3/7528686060615752153", + "4/4523698854373847874", + "8/1687282573424715048", + "8/3329600106422862068", + "7/8017850661860954947", + "2/4522534537199722612", + "2/8175993306461330672", + "2/2061844726435307122", + "2/4357308778576060732", + "7/5749929302059417037", + "2/7305439510235746922", + "7/6647422868989183237", + "2/5561431002238379672", + "6/1911168464910564416", + "9/3408188254951796499", + "8/6845417045203785568", + "9/288786441560725939", + "4/7562245497992388904", + "5/7047393832904767985", + "6/6205717948061698476", + "4/2261437227801221284", + "3/4948821583918098113", + "6/7777541061456816046", + "0/6983064454891729640", + "1/6965222730467574761", + "7/2695819361343807207", + "0/535847299517724940", + "8/7571486886492760368", + "8/7605613066630044388", + "3/4119171647104079453", + "9/3013875126032191859", + "8/1232065125542726368", + "1/6809299086020048711", + "2/3752686847197649112", + "9/7572298785119162909", + "6/7206134468092698096", + "2/7782932906908736592", + "8/2419116376449848198", + "4/1902002847773072324", + "3/6289460471097473203", + "4/2417030371199714614", + "1/2455173111990510331", + "3/837693308419586023", + "9/514778725067200419", + "9/8043295666948967699", + "1/8049405586737044061", + "7/322678593738310197", + "7/8740524876503946197", + "9/1270302343401053019", + "3/762171720723726023", + "2/9110228255863634722", + "8/7259418402484651628", + "5/707167041453258345", + "0/8797438423630194900", + "0/5746196675588195290", + "8/7019969990018061488", + "0/6280918187259393020", + "1/7155733656259991621", + "2/8330551504836740412", + "0/7647442117269349850", + "5/9047317752417487475", + "2/3488691670694418252", + "3/3165390486797812773", + "5/4863066808794270675", + "3/7697038444955343353", + "5/873985744684511315", + "1/5983785747627030801", + "2/7715517485850034862", + "2/6526725883123964832", + "2/7949441207284546892", + "9/1593980152149072489", + "2/5181119608253271122", + "6/2594677209000005796", + "5/7650724896127215245", + "6/637266387463678746", + "5/9166463926375586795", + "6/8851583786762687696", + "8/6777375236821733038", + 
"9/2984626323630976679", + "5/4400551462195871165", + "7/4001294447373424977", + "5/6399887576894682225", + "2/927268377693291392", + "1/6153739573126678641", + "4/6469013191380825794", + "6/5007312118647235216", + "8/2203183298847543738", + "9/1925728959478350849", + "0/8559449156931166340", + "9/6574259959749855109", + "1/7535630750915476181", + "5/4248395850933067095", + "9/2930419321195444489", + "9/2946013108117407969", + "4/3443316733636413914", + "2/5505467480227107372", + "8/8965178594768064678", + "1/1195561468208589301", + "5/3467852973802058475", + "6/8336649847748694896", + "8/2094826258462674698", + "2/5097801508757047822", + "3/1835899575202376943", + "8/141861298663485628", + "2/5954511009498076152", + "3/2598206716396361433", + "8/7742586811777572658", + "5/1689284054105770555", + "7/7450366073947885797", + "6/4784317285316364816", + "7/8956536144676649387", + "5/283574478237497855", + "5/6536389686350476165", + "2/4875391990883183642", + "6/4360968344422589056", + "7/2570611352686308177", + "3/6633315699578416573", + "3/4820344821608176823", + "0/2219049197924093540", + "7/226978087051072877", + "8/1461880307236421898", + "9/4299573952017235339", + "8/983370878266690778", + "2/345530964811523372", + "0/8609418469005662480", + "7/2474065617615745797", + "8/7287474234286212158", + "0/7314039967881696310", + "2/7338824273616211892", + "3/6774751492126094773", + "1/5160290148479104421", + "8/2053405120894102598", + "9/5297714910153666449", + "0/7562487208440209260", + "3/8915186287769209013", + "6/8384510362707853216", + "3/2449255814611930293", + "7/1838176824170147957", + "3/5968344511857198733", + "6/516494815568466486", + "5/3199743840641153895", + "1/222234957343670891", + "2/312440074879432282", + "5/4813838604196580385", + "2/6740772586675637222", + "3/1335241312977978453", + "7/8727066224337959627", + "7/6201572528039224887", + "5/5672074027257645355", + "1/4968323712259540951", + "0/1200559974098547340", + "8/4656956281634425258", + "9/5849381172878904679", + "6/1090880512254430236", + "5/5273284138764804545", + "4/3980521927255593584", + "4/2180933545875009414", + "5/1688995057429375245", + "8/2549547907674981868", + "5/1340723204663434905", + "2/8612794801976678562", + "6/206612489569069686", + "8/6624545242265368948", + "7/2209387025435056697", + "7/2987301109494136257", + "7/6973354216126412587", + "1/30361990069247811", + "3/3745078404411915023", + "6/7728312126366716666", + "1/8083071273834298721", + "9/1076981601852943729", + "8/3211162260882904828", + "4/1068399956173785854", + "5/560063358799009435", + "6/2264333571824810906", + "1/7113624872894404191", + "4/2562963740628450644", + "0/3993007504652272190", + "0/7201634209166570880", + "2/5459291555144339262", + "6/7107538805296791866", + "8/4622736056190054008", + "8/8105010259119461438", + "5/6939392311417218605", + "8/7800727000420766058", + "0/6820958742949032090", + "3/8432680131908319273", + "3/5812677046844479693", + "7/2059419536428766587", + "7/4263096714444181367", + "8/8768739658949830998", + "8/4128252223539998558", + "7/1916207582264278807", + "8/6920566518947971868", + "2/431713865646106522", + "1/73154956057083771", + "9/3118954820765386049", + "8/8947260193228419698", + "1/2031960878056126731", + "8/5182780320528288548", + "1/7866324839655188631", + "4/3894744169104605034", + "1/6713683777202441811", + "6/7351946193851819916", + "1/1447560231688688431", + "2/6865603867409991322", + "5/6800743224844513965", + "8/2370077341457668738", + "0/6264908970163155290", + "2/240712577021296382", + 
"6/8777924418492762006", + "4/7992613180591849424", + "0/2344105190912960810", + "1/3203472211669052061", + "6/5912783385472942246", + "5/4552698953368297025", + "6/7852820789156533406", + "8/2017288955897337708", + "3/4337503029151181733", + "5/3023256861808262485", + "1/6711246478341900951", + "6/7965665319397032316", + "1/8664655373742569271", + "4/1906351233884725554", + "1/1614084373638353611", + "1/4055018449565205611", + "1/4869670325514270991", + "1/6629595159627557391", + "3/7585930350710497743", + "3/2336604342891350243", + "7/3038525605934250667", + "7/1177637310994117087", + "8/8817308441574134898", + "2/6440210420953702632", + "5/4328053886523021665", + "8/5505685265083256758", + "7/1001275948686780577", + "1/4854598769700713501", + "4/2586529368902550764", + "5/3990751633817285425", + "3/8964767676466901493", + "5/3598188495457736555", + "6/1334412760342454336", + "3/692772976158089943", + "4/3804872845178539864", + "9/1044703742840097239", + "1/7244285169640735531", + "2/1299367643256975872", + "3/1394527666940960993", + "1/2810498016475225921", + "3/4806983698229749003", + "5/8596713903295385955", + "6/5317778978037314026", + "9/1648093541244212539", + "8/1376315712184702158", + "7/5615816835225758427", + "8/6683731031934060298", + "3/3444215787208976973", + "2/4234733289692520292", + "2/2062772670613631052", + "1/3273974720890270941", + "5/2059046742028849085", + "9/8922875978084912229", + "1/8936676274352684231", + "6/8544085375855246556", + "1/963886630917831551", + "7/7456443190043764617", + "1/6153123398881181", + "9/2263923873324958359", + "4/4219884292992314314", + "8/6914490025139613388", + "2/2660163059467836242", + "1/5183651944648808771", + "0/7167658058611181120", + "1/8735897422752828861", + "1/2453257353128835831", + "7/8990886184699497797", + "7/5129150007579972317", + "7/4853019104901508197", + "7/7814008583370853817", + "5/6336887289227144965", + "2/5045835748157545072", + "7/8499797159161854207", + "3/2794896040283657983", + "8/1461570105069302178", + "8/2379977264833811068", + "9/5127162940373768809", + "1/7131157560413998921", + "1/1929107246330492451", + "3/5971956834614159363", + "5/2063323527410061325", + "4/4959453736620007434", + "6/4366720395204570536", + "4/740921953343758624", + "5/7321942889647492135", + "7/364958652151649287", + "9/7408681020526997179", + "3/7955601389529721373", + "4/7776681781440354014", + "7/9167465307568787077", + "2/4167521460305639642", + "7/5363838867224642157", + "7/117131011405792567", + "8/1585797274818653628", + "8/5481398945485866098", + "8/8441846816113502258", + "1/5478053557001654971", + "6/9054394352302976686", + "2/6490546615782832782", + "1/1075842623888539261", + "7/2419408216478513407", + "0/1258947501014432130", + "2/8918950932208143592", + "0/8008170419771024060", + "0/4303931394050307950", + "4/7082910778403784144", + "4/6497519695252388264", + "3/4839676773403380673", + "7/4786192340403033387", + "6/165086742457007156", + "0/3324293443762710410", + "0/1682774148552185320", + "0/3872616618668548090", + "7/5640020187507474767", + "9/5986289483467422309", + "0/6661137413184023060", + "3/4002115698337785543", + "7/7769201645637930707", + "9/1093727570667436379", + "0/5677667025090732770", + "3/7200789449243891603", + "7/6440503567964109927", + "8/4399736432093231198", + "2/5777106476947864172", + "4/4359090367847228244", + "0/7204668905167953700", + "0/8800589386308311610", + "1/7139079555469494741", + "6/3511666369798054786", + "6/763680819594893726", + "2/1130939135602498912", + "3/4837425562510762643", + 
"1/8971798919556773091", + "9/6540135878220702809", + "3/1517000336203320713", + "5/7506377580714468725", + "7/6628337261055718887", + "5/8287646745144540695", + "3/4521283458467711523", + "0/8735417448316250000", + "0/1921458820952793080", + "0/7084451072467107780", + "1/3097556609595113271", + "1/9151313470875333331", + "4/4639133330857729324", + "9/5600409147013705599", + "2/5156647180364061482", + "1/2775573147944747521", + "0/7841446632361984540", + "9/7234414294892623369", + "4/61379182488740904", + "2/3289065806316125852", + "8/9065258277433766858", + "7/4133365735688699837", + "7/9100912758638009347", + "8/8395440585268819108", + "8/6435537942207493138", + "9/7899832700988591359", + "7/2351847574823934007", + "6/1858893671989783996", + "9/2372227954075228279", + "0/4794230662119909630", + "0/2325173510092748040", + "5/6016779880592276185", + "1/8896046644735488841", + "9/3351610922671104969", + "4/4171033703090443414", + "8/4255391488689353778", + "8/6124555166608505468", + "6/4554557805487450706", + "4/8775527557759575624", + "6/1352469873513669546", + "1/7298948548962810991", + "3/2604413506216462163", + "9/2002908236086948609", + "3/4006568588333650243", + "3/9013607533151912493", + "6/1705329668335242796", + "9/6822548774399997029", + "1/8537786175206605231", + "6/4692790360637246986", + "1/4679611180621144671", + "1/8858433977814734731", + "3/2541787245580055133", + "2/4461446014363062792", + "8/5312431304147866328", + "1/2135390553567527841", + "7/4011195138481308587", + "8/3037073310904545848", + "5/3895651220915859805", + "8/7115401450561970218", + "0/8374437323460341190", + "2/8624770234043260872", + "7/6469478441842421337", + "3/891244164600651843", + "6/3975423402214596946", + "8/3272633530514547958", + "1/5419454288879304071", + "6/1850034211886742736", + "3/7759254130066902653", + "5/4793445567435173465", + "4/6188463547069235224", + "7/2537027757296998177", + "1/5261553590106137281", + "0/3060576616508317690", + "9/1183803853509952989", + "1/3999674713395316261", + "8/4962713935038692278", + "1/890690136953530791", + "8/2198604042123848438", + "6/772568032929246206", + "5/5898626899785580965", + "7/7669373097034853147", + "0/1959740907814495940", + "0/4861638631585186240", + "5/2155395784193424625", + "9/177644183620703429", + "1/7795118936357032931", + "6/8214952066391058696", + "1/1628501975532529831", + "8/5552886918755707458", + "2/5513372799820378112", + "7/1386138814164009857", + "8/1474585050709633798", + "3/5025720019987463523", + "5/6696043012044840335", + "2/2280243886839226582", + "7/8142203689164212777", + "9/4494660782733125969", + "5/4715615334475754485", + "0/188274074674918210", + "0/375443313112639850", + "8/7995475573242037478", + "2/2097079811906449622", + "5/4521808845677495415", + "6/2056531928333177066", + "1/8803456785331952541", + "5/7666916213494412235", + "6/7428487963622878416", + "7/7513903156697115897", + "1/6243831706352568911", + "4/983962324800807774", + "8/1219578050090037118", + "4/8547589450614840364", + "9/869319032878659939", + "2/1554543400600154282", + "2/2989984812744789502", + "5/6564220671469507775", + "8/7903148181002726908", + "9/3950257696994982109", + "4/7096189538286646334", + "0/4974359925797921130", + "2/703838373038633002", + "3/7011781894313868223", + "5/6089821193505472565", + "0/8338708859949158660", + "3/3331581793716470333", + "1/3596073823653080351", + "4/581701394246811354", + "8/1364977411651464418", + "1/2492044781315407641", + "9/2823293794944991369", + "6/2485307982678882056", + "4/503702233138766144", + 
"8/583734468194150528", + "2/6623635364083317122", + "9/3368182564493592279", + "2/696899412855300532", + "2/9118351715271281342", + "4/1272660441710797464", + "5/1623750151180632605", + "3/2906093671464963303", + "3/4165927054478767783", + "6/7187932405572504666", + "9/3632758940054047319", + "3/819311080616735463", + "7/5606128215337099917", + "2/1204201796829180382", + "1/4632730131677219681", + "9/7669865424878131979", + "5/427773126925000925", + "4/6119376093825713754", + "5/3318758170145058325", + "2/3438970435231027592", + "5/4163109278434085655", + "7/108532988064169667", + "6/2401961446068185926", + "4/843963395517970724", + "6/697795133868752196", + "6/1249391125095381106", + "3/1064759411779628493", + "6/240689912767030046", + "1/6080319863108825751", + "2/7448477953944869692", + "9/4392823384546783929", + "2/8405196567738633792", + "6/4332409698316964356", + "7/6915991680210704627", + "3/7868991420795312873", + "5/866086646573716435", + "1/6983323204903776301", + "9/2159283151663500369", + "4/4571472004003098914", + "5/3854599633831606335", + "9/6477234439763478179", + "0/4746254182564212140", + "2/7287949722829928912", + "3/1595572115210014983", + "5/917923369744320515", + "7/1272925668827597777", + "8/668995997559424168", + "7/6091684176146257937", + "5/6050815609204893295", + "5/4510129620075591875", + "9/6656051556518883029", + "3/5481543690468533283", + "2/3374881481270467492", + "6/920921882879350946", + "1/1790052109088407981", + "1/37242364202981651", + "3/4820852255998380203", + "9/6977545908709890999", + "6/4014068062496797026", + "7/4855940873251212847", + "0/3051991851889860100", + "4/3171017434537774374", + "8/8842706903561095378", + "1/6713926218581034621", + "1/1131769260549437951", + "3/5319986814257972803", + "3/710695027540859103", + "8/5852162374689351628", + "2/1822123506533156802", + "0/841294608995543290", + "2/9152482058077058562", + "5/5469160871955202975", + "9/1842653219229426539", + "6/7310597944819370986", + "5/1651602087077283355", + "7/5393686233687223517", + "4/1782404710685811324", + "7/3218018896918193557", + "4/1811786328334971294", + "9/1555872544244633319", + "6/9093995410652265346", + "3/6471435486981101363", + "6/8702705418400433996", + "5/1327729301168244475", + "7/4183364297983554747", + "9/9190776492420372489", + "1/9162909123723678341", + "3/6450725497572843923", + "6/7791047630017922236", + "0/8004398022173461900", + "5/6881751630709166525", + "1/7495418920904880181", + "9/3878137092201727039", + "8/5904482653196179838", + "3/8379675613291902493", + "7/7252072149716820217", + "8/3275396851918810158", + "0/2828408126947103260", + "9/2692099439816868159", + "0/8393531459986951060", + "1/6177351665023679171", + "6/6494565591192946386", + "0/4647293997859285300", + "5/3716321473806351185", + "1/9050129669071228571", + "8/2081766541895096378", + "9/2376122398241578239", + "8/8064893311180417658", + "8/1829608282526051548", + "8/5588370251061488848", + "0/5501800478057769840", + "9/2367682417854737849", + "1/3318802821108196631", + "1/2313079725615839371", + "2/7686004621368092932", + "9/3129383223150562999", + "4/836132670204309984", + "4/838161638252237304", + "7/1071157004019001167", + "2/6380832707188530972", + "8/2706473048855380548", + "9/7713526076004060999", + "2/7706052876346738422", + "9/3806729425019889999", + "5/5549197548038711445", + "1/7498162149102330041", + "5/3578647482100625075", + "9/7058813318625074339", + "1/253901345176349131", + "3/1672866218214765503", + "0/4708908571398818530", + "4/5035665141717567664", + 
"6/5660387282837081206", + "1/5420642910684011601", + "4/2723250576156972784", + "9/1625952172954771909", + "6/6860312402723229906", + "3/2335569681387188123", + "5/2236478215825541405", + "9/7866587809396920029", + "3/3356415730615695143", + "3/7111469413025416983", + "1/9033197102946877321", + "1/5213247422684402371", + "1/7971714896421466441", + "2/9162828813460295862", + "8/3781747569395994178", + "1/3483101539728317681", + "4/7738157712230778224", + "7/6736913133116102987", + "5/4822316602453782055", + "1/1730534897091030031", + "9/1594042648047613599", + "5/2105481742237380605", + "0/2839526287160842700", + "4/5801638869759218804", + "6/1115463623117643246", + "3/7984975477586230233", + "4/2921336439543148034", + "3/2807713388333566813", + "5/4382261067401772085", + "8/6767841328531552568", + "3/887776842981115053", + "2/879725356723745112", + "9/7113889004343057409", + "5/7505937294264768965", + "9/7261893733400566429", + "3/7732471422487657973", + "6/5075984983136235936", + "0/9195953207926911830", + "9/8703113330730769819", + "3/4431791432926684773", + "2/6765757824171538632", + "1/4636330178885385401", + "6/6993948603948315386", + "9/7213699857127808919", + "4/3375555473999615904", + "1/641150769790876691", + "3/7891847710925220063", + "0/434871904799223760", + "6/5360927264723930046", + "1/4089990932457933811", + "3/3203837813450525183", + "6/4621878876867532686", + "3/4255520082615966313", + "8/2419016609674066848", + "2/5789366557962814632", + "5/9135088780892380605", + "9/5126880188622275049", + "6/7143038144394430376", + "6/1422160839465255066", + "9/8804322196424760539", + "7/8825408687760073627", + "3/2347940326961804623", + "1/8157948110404911621", + "4/3602359969987581164", + "4/3641312082862486094", + "2/2991393834296826782", + "8/4796006493771266878", + "8/7491089695545700028", + "4/1939194646822269644", + "4/227493903386791184", + "0/8777223344947987690", + "6/1613182395139762386", + "4/7503482718456416904", + "3/2050825423227809033", + "4/6518719133695467254", + "6/4295793528971835626", + "8/3025856300443519648", + "8/4272697248719297068", + "7/868626402934766657", + "8/3014201390185801088", + "2/6644367134279721672", + "9/7897240399382325979", + "6/6722812796216395176", + "8/4233387915859760368", + "8/7934532682444778358", + "9/7930352913187985159", + "9/654430268613137599", + "3/6939460563911338733", + "5/8689712407230470435", + "9/5951168406327139949", + "8/1952546574100651378", + "2/7667654614599930802", + "4/5881834820973554044", + "0/4657751061325609130", + "2/1327092672246986382", + "4/489675514011716364", + "7/7938456829168944697", + "3/6282954362890159043", + "9/8460260532624683509", + "1/794615227486731001", + "7/5167133522137341067", + "0/6558100820075763550", + "2/772865032277792982", + "6/8914390666798052006", + "9/4690499570533686329", + "2/4888978989525463772", + "3/8573859242196016153", + "4/2929932472267418714", + "7/4965706196020366667", + "2/6729004001492427702", + "1/4301273426700228781", + "4/2056314216403046394", + "9/6711953655670008629", + "4/5667363927733355624", + "1/2939082903727666681", + "0/273825114217483940", + "8/2975104187052576538", + "5/6553563421779226105", + "2/5769309001698910792", + "8/9047978792178743228", + "3/8494687235128265913", + "8/7264176805929085018", + "7/4896421765475612457", + "6/3842150241576777346", + "2/4392402144170594282", + "9/3001014932070419669", + "2/2160787905481561442", + "2/7753128985817275322", + "7/2751465388730196277", + "8/8331306850548080228", + "9/1099991847171718239", + "4/4908093880881828774", + 
"4/8055906888350119364", + "9/6805284046727519829", + "0/2774196279009216940", + "4/1100050704346030754", + "6/678718651518273866", + "0/1279595102384408190", + "8/8046661348863497658", + "4/6189519901163087834", + "2/7208458650122069942", + "3/3117980490727190613", + "8/7702219984281562748", + "4/867762079542052284", + "7/8360970015168650827", + "1/6213335165106599881", + "1/6389561570881217511", + "3/2257142630648888643", + "4/6714172372514036924", + "2/1947748507051062652", + "2/1596450374399308432", + "9/1383332182496830739", + "4/1301346675128687484", + "8/2956504678580594418", + "3/4009208193967831273", + "4/8960699699130987554", + "2/4394701110055791062", + "4/53665699766149364", + "0/6363574465938822600", + "4/7874190671782876074", + "4/8127918762624373814", + "1/523395940650418381", + "9/3152020131287452979", + "3/6237314269486860923", + "5/2085681570473420985", + "5/5564973231147984475", + "4/5095083118907676264", + "8/2135602213192076618", + "8/6482373162678611098", + "1/5886390490289149321", + "7/6591374275968379347", + "9/6258995428325724779", + "3/7513914456245757013", + "9/8901641865001747769", + "1/5264936350262658011", + "7/7712349786061751627", + "9/7006428610946487399", + "0/7924901151459468280", + "5/7363859368201845595", + "5/7712841033451209755", + "0/3480429732768105720", + "0/5151984901263205630", + "4/1352134288762112924", + "9/4570797459392749309", + "3/2216133263267455143", + "7/2299300447267810947", + "4/4854218265720612844", + "6/1152924185221071976", + "8/3897996975570063848", + "4/8756842847257575064", + "3/1996007958835719473", + "9/5403806078364138889", + "2/1079016274966487822", + "4/5614521851486837744", + "4/875464064089966864", + "6/2820402645093840636", + "0/3878842089005275770", + "5/6058480761868724685", + "5/6356867536946099395", + "0/976329964746524550", + "6/1071418655935432286", + "7/1645873195689114117", + "7/532552848643382777", + "8/3653454244432686628", + "1/8467336555942327761", + "4/8776168852245788554", + "4/854373604872008344", + "0/6045005168166653590", + "9/3596421830522992769", + "6/1309626083712156016", + "3/8002273135540226203", + "8/1906362233104643008", + "2/6302742622923994562", + "5/5850994188310591035", + "6/8234141108907865576", + "8/4653477885447665848", + "8/8767535399183135688", + "4/8021279369889142944", + "7/3665913406416693357", + "4/1063606854645626584", + "6/7435631846175466446", + "0/6926869549205814370", + "5/5579639280588619405", + "2/1800671567778225362", + "6/1796089602764175936", + "8/309555111778870628", + "8/8724778312104858408", + "7/2384757039951031007", + "3/2972287590586920913", + "5/8618164568547959255", + "8/4910318095676496738", + "8/1653654050556969098", + "1/1016079725076527481", + "0/3138445241869893150", + "7/8034577580447212417", + "1/1704414771543567471", + "0/7322455153625516000", + "9/1524492589688145419", + "3/4449535904803874973", + "6/3141521312554708886", + "0/8234579611156046940", + "1/3786267114280801711", + "4/2409800297994564564", + "7/6426227355211360367", + "8/1901441349922348188", + "9/2063055892288483689", + "1/4693719590521694361", + "9/4685946231788106969", + "6/1155623746019795156", + "3/3685543217747435783", + "5/7776055403378643965", + "7/2016511229866901197", + "8/7253576525450529008", + "3/8130372855991481783", + "4/5489159121182749904", + "7/7924660200369299767", + "1/8131468500020880191", + "9/8810847794842532209", + "0/6074563052883127830", + "0/5834949080810369420", + "1/3920278363424195171", + "5/1711985605383223655", + "8/6897237049128188028", + "1/6049365535823533731", + 
"6/3537568739418011506", + "7/924720500117988977", + "2/6886417100182880262", + "5/5798835956913008825", + "7/1211598850872718847", + "9/6953043680503830509", + "4/3592706025356098904", + "5/4286363667122813415", + "8/595877864809871458", + "7/774990163610000247", + "8/3674679554976220228", + "9/934061729413119559", + "3/1266911412170010483", + "8/3528939562118133658", + "0/3895584734015920940", + "4/777433512032783244", + "2/102160958756299902", + "4/4371132039277509174", + "9/4751316413217386729", + "7/5943736298509996677", + "3/5580708046197655093", + "2/805099419383588652", + "6/3759590635384705026", + "5/3319817843682994625", + "9/6024166046800529109", + "7/7762589144038077877", + "0/5952952956798802070", + "3/8678127897473467973", + "3/1603275115973456053", + "8/4946466284164823428", + "9/7371201956919680089", + "3/5283458459618495103", + "0/5821968215609305170", + "0/93344917579228140", + "5/1198063839036133375", + "0/6276693827641726240", + "6/549707344823926496", + "6/188270481661468686", + "6/7499791024527235226", + "5/4008648147368907105", + "6/8115889037255303046", + "7/664070098966932257", + "8/8836837925835380548", + "5/7972900813621041355", + "6/2530784329691948636", + "8/5003304999237064398", + "1/8085418943914658461", + "6/4879172236475524756", + "0/2211698781615780990", + "8/8142361040371535728", + "2/2775747553480491432", + "0/4311228893966976340", + "2/736186984095704672", + "6/6310880967107998016", + "4/4719584812246709694", + "6/4588914364849097216", + "9/5495869968733141819", + "1/4177448355209069491", + "5/5868913774049229405", + "0/1872196518050675910", + "3/1467282679131856683", + "3/7417270525023396053", + "5/8300798563953855425", + "0/7139091799513241490", + "4/6854609366117194864", + "7/6485197668564668037", + "1/5046427361098778491", + "0/1738996183053482330", + "5/613572801427284405", + "6/2469968267997883416", + "7/1198508932911996137", + "6/7732917911283210426", + "6/7897416096150095706", + "1/2140834189525565891", + "1/8105372557151104751", + "0/1926316112310895280", + "3/8564745454825487033", + "1/6319215208492490311", + "2/5671376004991695122", + "8/8333140116441031518", + "5/1518384985418501835", + "6/3268065270227669266", + "5/4450486459951461405", + "5/6903828797741064495", + "4/6784520107746455804", + "4/8091610111794253144", + "4/4315913522583493594", + "5/8484452483174778465", + "5/4940885904179978265", + "9/4352816043794880979", + "3/1635480120569624363", + "3/404974278667319003", + "4/4826166863341949564", + "5/8405672971921034915", + "6/7342280188379138666", + "1/7343431620431733581", + "4/5035244349674541804", + "8/6538357678149136088", + "5/8394126462202683825", + "3/3465315339312066313", + "6/8310609348908832476", + "9/2115070951038354559", + "3/6052234555531379953", + "3/4922639049898806913", + "1/5850510583353445311", + "2/7421252918637210012", + "0/834405263842379530", + "0/4518757759610086640", + "6/1509064826205412696", + "1/6921946649275728501", + "0/6767674508877183930", + "9/4878148597711415769", + "2/8484382488075366132", + "1/6807161843973371011", + "8/8734814121313688888", + "8/4463735526029762038", + "4/1377291713465836274", + "9/1910095295739071479", + "1/9163664459890061861", + "7/8000187634514294757", + "9/4802832338580632109", + "0/2387506807976752030", + "6/7444702488118372806", + "9/4279549350469526569", + "0/4506752172236118070", + "8/371303286337396998", + "3/4012897761483117833", + "9/5112151897744948489", + "2/8344221532085041582", + "1/6101590973454290941", + "4/5106300963658417234", + "2/5194811287508650762", + 
"1/2143226385961211671", + "5/2966989870508698945", + "5/3192118692131575845", + "6/3785221552436802106", + "6/4264915263998055706", + "6/2378018192190825576", + "7/5360395696046039077", + "1/8053474246998414301", + "0/2225814550013773090", + "7/3187808354585955327", + "6/4801556999610631656", + "2/5368405701543666132", + "5/4177681535314588855", + "5/302030426977948945", + "8/5768482342780136288", + "3/2774910068907685253", + "4/8421286427001262534", + "1/5402324066234565591", + "0/5450791992253254160", + "5/9106224559710629415", + "2/3010372225331029242", + "1/9116768293335572431", + "6/4214510895196619916", + "8/6974094515735760368", + "7/1078929730561102057", + "5/7651763313580417185", + "3/5132798291335678023", + "5/4001765194607581245", + "6/614839544970968586", + "9/1122410054536558449", + "0/8516068067739013480", + "8/577843622776904338", + "6/4227090108143291196", + "8/2519274834198242348", + "7/5381742983103544567", + "9/8712625631318553809", + "6/4138773827054905036", + "1/1417689473433242451", + "9/5304454844678467249", + "2/5260472805323374692", + "7/3778859009117698247", + "0/1868237931664165960", + "6/5268834393335076066", + "2/3195260392287282722", + "0/2141494998974786960", + "7/5722722301047314527", + "9/6201623342959794279", + "2/3051970609331355692", + "6/8908856044584625436", + "9/4340518083750162199", + "7/2105843759857931137", + "7/5888811164214051267", + "2/5888237055865353702", + "5/7592034841669728715", + "1/4343221595035488901", + "3/6416969242395519083", + "7/992742288557030377", + "1/6947856507680991831", + "7/6980222372196145007", + "9/660115213436551309", + "0/6815212783537251400", + "0/6871117594845158650", + "9/7000756696369954449", + "3/5575636099597180123", + "7/7912356580016626257", + "1/6354300123677271981", + "7/7874852256638512927", + "4/6703993443433078104", + "6/8341637841944661746", + "2/2395497967628206032", + "5/2646975357435184535", + "5/4183408317789486465", + "9/8316812053721762509", + "0/7646624940245948350", + "0/183907857502366340", + "2/1039510241654690522", + "7/4739412149587899987", + "8/3119583419548593948", + "2/1059036574706667702", + "7/3522053891775970657", + "6/4895372387756869296", + "9/5134886346358421789", + "6/1906702211563512826", + "5/5161065409346175405", + "5/7064826677969378575", + "5/214177236959077325", + "7/2584281921390981617", + "0/6461372806987432490", + "4/3062239418206712304", + "5/1151795674389057645", + "3/20225664556850213", + "9/6718623985042872519", + "3/9077577477176826173", + "4/7310260224715609424", + "0/6395854213579377930", + "7/5577282093239101827", + "5/7685782854697470445", + "5/2975862389062298315", + "1/1469797162338836001", + "4/3255900920429697394", + "3/1835994572562990833", + "8/6058492061905130038", + "2/142414084189733282", + "4/8405489314705324524", + "0/3028657058954104710", + "6/1603821616363220186", + "6/8229236335289113186", + "8/6123289608880774098", + "9/5090570323004607659", + "2/3422270282540029732", + "8/5312889934898269738", + "9/7155037646270166979", + "6/8873409808654091946", + "5/8141412021126346735", + "6/7928043100833956346", + "2/1821097988407844792", + "7/2293257581644114287", + "7/2310621882526083657", + "6/2852275500459328246", + "6/5444296460388049196", + "7/3348183575678402367", + "9/6888820850006789089", + "4/6461960670287462204", + "0/5281133288621340270", + "4/5749824598023288524", + "7/1221928582747353287", + "0/1893652171856503840", + "1/754525117494052451", + "0/4497958496244923370", + "5/4941066515041883665", + "3/7703425575429022003", + "3/9001865128144543193", + 
"8/2799017205218875888", + "9/5398983379268517949", + "9/8002590021332597209", + "3/8727832217237132293", + "7/9103180319035267607", + "3/8814853464517807013", + "8/716607628241062798", + "1/2803961447831582951", + "3/4248313807636410363", + "6/4538377916758045996", + "8/8420939130766608068", + "4/2144263987876519174", + "4/61172549064631184", + "1/5829453442506092661", + "0/7873097826542233890", + "9/6405641838146479929", + "6/2568484313879396136", + "1/5681713545733573891", + "5/735315619620346155", + "4/5878964422962758534", + "0/9114810641719717370", + "9/3791564208510111869", + "0/1856377053382217470", + "6/1457740408426242456", + "5/4748184674918855185", + "6/5732901393231768626", + "3/4632243465461989613", + "0/1383750324070740740", + "3/6191833227484182283", + "4/2172581141013582164", + "4/7045961039347139024", + "5/7089578408855848915", + "4/5191621268467024024", + "2/2344680047446383402", + "3/6720193633695801523", + "6/5093781261806448736", + "2/1408939175752628502", + "1/2326532951113558711", + "4/8593232421020701254", + "6/7775601508407199396", + "8/3364902627332958948", + "0/8437028400148429690", + "9/8677348015047348339", + "8/4233739234952607898", + "9/2694102959900555799", + "3/181364387222468313", + "8/1498089808889317098", + "4/8862875536818418524", + "3/3525791612577328033", + "7/4257544651424114977", + "0/6795821652068830160", + "4/2549962075505747664", + "4/5685923738509281784", + "6/3766435235535895166", + "9/8760657023196482449", + "0/8148110889252125690", + "3/6260806237225927323", + "4/6084497388589527244", + "6/8788039270243795906", + "6/1052445403178357426", + "0/664175832066851580", + "4/8615730909363220724", + "1/6613044604330949311", + "9/2865514764914900159", + "2/2221105437507294692", + "4/1993182979350776354", + "4/2981453517967612954", + "7/4783025138317778907", + "9/8623003622590795649", + "8/2038174345194112278", + "7/5845953777167914027", + "3/7291899527180136163", + "6/4468064246693787126", + "1/7997844619095837051", + "9/6861904002063362829", + "5/1218337097065897135", + "5/3608448755355315215", + "7/173772703949996297", + "1/1459718474511782691", + "0/8805113680882661910", + "1/5166431478277135021", + "7/6915766721035614387", + "8/2522419413316153448", + "3/7338815437281089693", + "8/2509597935849732968", + "0/2476547208890462620", + "8/6534394813254565688", + "1/3696133284886713611", + "2/7514507093071719362", + "6/421186085797790036", + "4/6197857164424096094", + "8/6080893915008578378", + "4/4269298378545987884", + "8/7017370411993451968", + "6/4153915045854434416", + "2/4853430112364518172", + "7/7475411118734137557", + "8/7787443422588158358", + "6/5572992320889837756", + "7/6851267351205378757", + "5/8249259982063782225", + "4/4076831879646660254", + "4/562954568302416764", + "1/4160900049324395731", + "9/3802193144742469159", + "7/6275897531437482867", + "0/2731375422904745020", + "5/5463109976422547845", + "8/3989998979640668938", + "7/8568252636542750617", + "2/2508876686254270972", + "6/4720557394574193996", + "2/334952153970934562", + "6/3403763493618741776", + "4/1186706024373335354", + "9/5336163105949149129", + "8/2362302671698633758", + "1/4249674584697782471", + "3/798947683274570483", + "4/8070543453958594274", + "5/3276352939381180835", + "6/1955859102490107556", + "0/6092469035086739730", + "8/7151208671537107798", + "9/4549977091435895829", + "5/9184849915431668565", + "5/2319729864449927145", + "5/5598451732881470255", + "1/5303124725222601231", + "7/5386663539470272777", + "9/1873742662750852079", + "1/7736077787257784631", + 
"7/7611934133767752567", + "6/4152239607101875816", + "0/7088924578953179820", + "7/202684425822406657", + "6/1555834798980312156", + "9/6036999704847541939", + "8/6639135350631818768", + "7/8182195447162419957", + "2/7049965546005720142", + "9/2533705966722079589", + "8/5657226896149435408", + "4/5797812689481340954", + "2/3154254432283287592", + "6/8446888699628142326", + "7/992370222103801117", + "5/2801395902410527235", + "5/1445178862043284715", + "2/4810786262266688672", + "2/5656950874195901872", + "1/5819938522690889441", + "1/7801010784358651341", + "3/4690219138491622653", + "4/4970337013799049314", + "5/7450932014823577715", + "6/7740791060622184646", + "3/7706439894036673533", + "8/858152378224942798", + "4/4109340378207584874", + "7/1937670498145490257", + "4/4377701689447813874", + "2/3398313485754661212", + "8/5765733648178500208", + "9/5713171346023885789", + "0/8710151161472384840", + "9/25702208362782739", + "2/6894506726394630992", + "5/3106492193812707965", + "6/8211589499107207746", + "4/5126457399586514124", + "0/3877494189957255510", + "6/1259123034343271726", + "9/5592756183158830279", + "5/300490283109230915", + "7/7738349894064766977", + "9/1385009293177630539", + "8/5779227897853625198", + "2/3020476224201510522", + "3/821021436293024873", + "1/3531107376985865521", + "7/7215540933695213927", + "7/7760764920550024147", + "8/8718814061244833118", + "1/8293573802339231131", + "7/336281830625976847", + "3/4483414875238004743", + "9/1425114514332279019", + "9/4515335745960789819", + "0/6225584509397274330", + "4/4668782707485659134", + "3/2204841658305561853", + "7/7035846646752153107", + "1/523598252987007321", + "7/8008888208934852597", + "7/5540293268580830997", + "4/9186332036400624514", + "7/4543024145061153307", + "9/7953110935420927289", + "9/6262602453978395459", + "6/6749917091852113736", + "4/7371417046200586614", + "7/7223238637310723177", + "0/2395050561802155380", + "6/2383772126877023906", + "9/5573474578425335459", + "2/2305641621635680252", + "2/9067099211541160682", + "5/2618747745046262165", + "0/6162597630408884100", + "2/3485593427598862772", + "2/3708456344856391162", + "2/144033688736192562", + "4/6944240538578430844", + "6/9107480203185842896", + "3/4234527897594640193", + "4/5854395701362535344", + "6/8400159001992617256", + "6/9041775430629453286", + "9/4192013603992279869", + "0/1899970979779469200", + "1/9185825961107525131", + "0/6569149197856563010", + "3/120682071569470123", + "6/7639648587165833256", + "4/2832951763084104344", + "9/8974076899370431529", + "9/4802565074249784459", + "0/5365633810612934110", + "5/2980346578286950375", + "7/4413128709117248607", + "0/111332273246653120", + "4/5212589392872775634", + "5/1569598676177551515", + "3/8358984829572964123", + "9/6336209142368274809", + "5/1710154829901566435", + "5/577515401879178205", + "7/4226235391029714827", + "2/2754443209819987312", + "8/1893726947387321248", + "3/8710875916459471633", + "3/1938168502649588703", + "1/6509030612452732641", + "6/7659213507550028686", + "7/8763095921829349497", + "7/1733546621011217277", + "8/3990166296205853348", + "1/6319371061963780971", + "5/4930670278777118915", + "6/189749795989114196", + "7/4988856258829834047", + "9/1271065708974602779", + "1/8651477780713132291", + "7/9181181297541451867", + "2/9035450392825611942", + "0/3960128639328682920", + "7/6979624310961742397", + "8/4065519256609624848", + "0/7604682266301790990", + "8/4455073191428034368", + "9/516327841455560609", + "3/3070974174005101393", + "4/5036457925839246344", + 
"5/7369079138966484435", + "7/8698349421508003207", + "9/2543972424941234689", + "1/7205348126170278551", + "6/8116593095973637996", + "1/4832240706625131371", + "7/5208394240130356187", + "6/2884354689824751726", + "6/2382588035265562646", + "7/5592487491855013507", + "8/5946995853109687448", + "9/6239396427763533029", + "0/4449417089315085640", + "0/5554398559458152520", + "9/1293067402907106909", + "0/744346888595635900", + "3/3872129583662245093", + "8/5329919730755775298", + "3/3050845760193283223", + "8/791924862533485628", + "3/216214839771849073", + "2/4660044915035045172", + "8/7526502462603893538", + "4/8705184649592900804", + "6/6239028721065954106", + "9/2764166327757389989", + "7/336326409094997787", + "5/8746462757086269445", + "2/9056186386598695392", + "7/571524928428964707", + "6/4818842569381713676", + "9/8444142433863839029", + "9/5242334423202336429", + "0/3338394351794840770", + "0/5988140637350829640", + "2/8121927094164064862", + "3/5793365020885368793", + "7/4132960724069979487", + "2/7326424550814939592", + "2/7600935093691746882", + "5/859134279329125375", + "5/2942932461495639375", + "9/6202489082071839279", + "5/6092323696952852705", + "4/1641638469549600694", + "4/6284166620038466654", + "7/1801286527689757567", + "0/4395774892235431410", + "5/974237423423184015", + "2/3573187218886546312", + "1/2290048307533210901", + "9/953581606177668959", + "2/4376469645858959802", + "3/3733445186299046263", + "3/7894480858404864473", + "5/901277655612978205", + "6/370379529098331726", + "9/145169368815004339", + "8/6030148889324680778", + "2/829133100868062702", + "0/8726821150860780470", + "3/1951296221717730613", + "3/1694614593636676983", + "8/6687532758738137188", + "5/1093884124607588575", + "3/3381361178805416423", + "5/2314536487437756955", + "0/7766370075146139410", + "4/527389484389331754", + "1/6728832197198818621", + "2/5666086802860575762", + "6/2069622600969640736", + "2/3361952480605675852", + "8/63514914715401808", + "9/2990646221785097139", + "2/4013228210912311942", + "1/5857106815680428641", + "1/8663372489600612961", + "3/5135297894314829183", + "1/7349236077114615141", + "2/2991548871414310072", + "2/3474307973493084792", + "8/3665163430507825778", + "8/517065777212918018", + "8/1491902497101246788", + "4/5421803604029499514", + "7/4058544764288868877", + "2/3966878440602377412", + "0/6262545079697748710", + "1/338623800314488881", + "9/658131997932176709", + "2/483701758467973032", + "8/4077238288334469938", + "7/3069296909705004387", + "3/1751076049088054923", + "5/7277857483551450865", + "1/4686190781584724231", + "6/6445075834808473166", + "7/259990774085200377", + "4/6920118971293393844", + "4/7956309335296231494", + "9/1425997389289558709", + "3/8925980635027556693", + "6/627726377753647826", + "3/2236281770287129803", + "6/3974821166465991796", + "1/3178631086149138381", + "5/5399543496928208675", + "8/120012462433532178", + "0/454216625739754450", + "8/6787163034229240838", + "7/4964742196736765117", + "1/7913676393638640461", + "8/8454276308601909238", + "1/4612703097244288641", + "4/4736413079532341654", + "0/3140903778103583830", + "3/1670931150851490203", + "1/3269513058681306631", + "3/7562984795786247593", + "2/2418052543531788472", + "8/2749361628782121748", + "4/6715302327334770604", + "6/7173510333740850936", + "1/8107112093093434831", + "3/7183809003106105853", + "7/981898385906410617", + "6/3808940577700965946", + "0/4363616983742103200", + "5/2346887950518239555", + "6/1077016798992795486", + "8/5458011957658616818", + 
"9/542383535991105439", + "1/7891028594949084611", + "9/4710460413808456049", + "9/412005015840081299", + "0/4105600236666635660", + "3/1997246514715262113", + "5/4886324391800312915", + "2/9164333760156351952", + "3/2255077519262228973", + "7/3849906660155441127", + "8/706772989294011648", + "0/3956244387337855300", + "5/5846630543406078235", + "1/5095518084022036961", + "6/1917571456559803126", + "0/1444282168017799060", + "1/5722646392605752741", + "4/1955680116965109664", + "7/7092940486348319607", + "8/3512411176576977688", + "6/342095809070349346", + "0/6398371425973729540", + "7/1054298803054846207", + "5/1733401653491079165", + "1/4242779282051930311", + "1/1592161923048188331", + "4/6049532156959307764", + "2/8655581145991166702", + "4/399438378340104814", + "2/5085628327576700452", + "3/3254693805525326653", + "4/6391371473190793324", + "7/5363669486736452597", + "9/1662061131176043099", + "7/6833553403426682197", + "9/8558297026476766209", + "3/5611508807452564433", + "0/7776804889307629800", + "6/8703222588194245556", + "9/8865372569007250559", + "0/3679523714756831200", + "9/16629098228437899", + "1/5598758020936908201", + "9/6343574281710003719", + "0/8576671924329788810", + "1/3600904134636034471", + "2/4825808665168901772", + "7/7211182706627192817", + "3/4579710624623460083", + "3/1842309336504834223", + "8/3912874140832856038", + "1/6850070054031594051", + "1/7057194072147872011", + "6/3647870255840024606", + "9/2389426165580759309", + "3/3834572316987831983", + "3/3123188731691264713", + "2/4790934918684826912", + "6/348054402631750846", + "8/3638990068285608138", + "3/4126206910126202693", + "0/7496900540385186910", + "8/1328910488253728238", + "0/4437312224034742550", + "6/666744843545243216", + "7/8863502382155674157", + "3/6552408864062077873", + "4/3907649782078189124", + "0/3965411277877895610", + "5/6580385666707423565", + "7/755246638429344217", + "0/3766069854738669790", + "6/467574203891085646", + "0/5677031321413240170", + "4/2739693510492276804", + "8/868169300562458178", + "4/9140541221099897474", + "7/7705521291412780697", + "0/2033831568594342640", + "8/319285521906020388", + "9/2255723357871796419", + "5/4564817018053460815", + "4/6765732412546818554", + "2/5651844810291349372", + "1/8127082137929448131", + "1/7024304392589961271", + "7/5027730783509956227", + "5/1218903098499870655", + "3/3001574777560317093", + "2/1257820155374701382", + "7/2078696694486322157", + "6/7417247902758467486", + "2/6442297820900808472", + "2/8049658661498918662", + "3/1258753207174753063", + "3/7154521691337492923", + "0/8524266909660390490", + "4/775121736883739994", + "4/6149434999362947414", + "8/5994142159891478778", + "0/7020701375924349210", + "4/5591438386820897474", + "6/7373688601698027566", + "1/7411947675707280631", + "0/4603378324246093870", + "2/5328572470088799602", + "9/8889603373369783489", + "2/2640155594981957212", + "2/5638111985472494222", + "6/6632949487434945616", + "9/6417615023039345359", + "4/4420433107689405294", + "9/3261193360523687949", + "9/8151529080742268519", + "2/2518506652783260602", + "7/4864054810985074967", + "8/5297357507694995128", + "8/2711504927355878618", + "4/3425960501356503484", + "8/144747400983003588", + "8/2265348970376350358", + "3/7561475388387916083", + "0/3604417146255922790", + "3/1054880701434172673", + "2/5745292717480303942", + "3/5605010914074891913", + "6/7840134929714387786", + "2/3818327287904270532", + "3/688203144880428243", + "8/1977721529042379268", + "1/6661731186544010901", + "3/1387116541968804023", + 
"5/4399308183431728375", + "7/1796254415986388947", + "0/3550236534922020260", + "1/6999411397032976341", + "4/48828951002328264", + "2/5248433742040134132", + "7/1694648712048699567", + "8/7481807269716102838", + "3/4039431816670854013", + "2/5162790751283987642", + "1/940556474393642781", + "2/1240869362987573182", + "9/7705723405386960869", + "3/5976296906635233193", + "2/8501430030297323652", + "3/8423385274521872243", + "0/5104728487902097920", + "6/7618478798452988416", + "8/6348632088218719348", + "8/8194586416075868928", + "1/3542915393155774771", + "0/1086104341257357750", + "0/2345925167963617960", + "4/7731405011875941094", + "3/2497288247905822453", + "9/2381050717090103119", + "8/3608425395086646928", + "4/5932331136617025954", + "2/8125722730580362502", + "0/6998505019688906670", + "7/4612186942107053737", + "2/62134075074990212", + "9/4960608580173084129", + "2/5634565105774328162", + "1/3253597765743887151", + "9/1439744302679883449", + "2/6276456162474508112", + "7/8115442875205605417", + "0/932736021432439820", + "0/3134858340273419210", + "2/6944444651371698782", + "8/3066970195304840878", + "4/4022470079644344684", + "7/6583321529740277447", + "2/8977404519790621082", + "0/2636053158702330880", + "0/5468377297249509370", + "0/6848430247017769140", + "4/8825727034937523474", + "5/4326292597002533665", + "8/4557658350543572938", + "3/8116905130514578493", + "2/7448318568739548982", + "8/9039439966698699218", + "4/3477741758697288164", + "8/3018138086931081128", + "3/6641440007502131623", + "1/3977832195791253211", + "1/6855039837605057611", + "9/5452757484728394169", + "8/875312056732710628", + "3/1033463009526857103", + "3/7094194113937784233", + "0/1008310613509133290", + "8/1861442983252603028", + "0/4312434515318330670", + "8/4239866390366356148", + "1/325072800551330411", + "4/6658827659960300254", + "0/2699748385794095370", + "2/2748710214622461472", + "5/7540657305634941865", + "0/7798447377942690", + "7/3621644345067921057", + "9/7305283064232185709", + "9/6015545087251812869", + "5/6479822878657193895", + "2/2883391233340004902", + "8/4054165151539833808", + "9/3515636814879529689", + "1/819447960032712841", + "4/1803243465058042814", + "8/4389693323260924008", + "0/1722614265537917240", + "2/259565388981237582", + "7/4565541047352487167", + "0/8919846134426001840", + "9/2721811457938350499", + "2/886057799935114512", + "4/1090474905405479754", + "7/5886598374098774817", + "9/4807021703227997439", + "9/4829944706346175719", + "5/2743905066072938675", + "6/2899988609449326616", + "6/2030800248052642676", + "9/1637634195504399509", + "7/5982059694620103617", + "2/1076019316624467752", + "5/2648637724352168075", + "0/4610205575609865810", + "5/2899423985665904145", + "1/4090920800054049841", + "2/9113351576540908882", + "7/5812991355044219907", + "0/2762378205544686260", + "0/230202626758789020", + "7/3714113790510676567", + "3/19988700394869513", + "2/2699043018121883892", + "2/7897353998896698972", + "7/3252177732926636337", + "2/4112959931130261732", + "7/4651232259934869177", + "9/3166893592725265229", + "9/4865140289762274259", + "3/4178308177512802003", + "2/8963902730236334412", + "0/6428390438635892460", + "0/5852786492213752020", + "4/1817189871780993644", + "4/5263233711680159014", + "0/6579454328364520340", + "4/5471487049454825334", + "9/222910201852016629", + "8/7310464330636858538", + "0/2880512674266677230", + "1/3505528230740792351", + "7/570511947064352877", + "4/8356514890186998954", + "7/2420396570079432557", + "0/1245551547183764380", + 
"0/4640111154197522330", + "6/7489192264184863366", + "8/8922451551030230408", + "9/3168109075276903609", + "9/2055081844386641419", + "1/6055280334695119571", + "5/231931955739550455", + "5/5199170736416620595", + "0/1229661039094377790", + "8/1635314435289224118", + "1/8675525667095239961", + "2/4665767980166602652", + "7/8269945708481766867", + "7/4936642069086177097", + "6/1324915746917125316", + "1/2506738040855989611", + "0/5538041866809564590", + "9/3493112750320977069", + "3/661600091288288883", + "0/4507193832473717440", + "5/3648020358587605305", + "8/4644701493295481358", + "6/185921207412044546", + "6/465275710172167896", + "9/790974552774534589", + "1/1291171676559447161", + "1/1659655270757096981", + "2/6327605299194554072", + "0/7212360044391798630", + "5/5982879696205554325", + "9/7234150681042522839", + "5/1873027718107222925", + "1/2222184718401668421", + "2/1442779372064889782", + "5/8297759857321036795", + "6/891769111166575206", + "9/6731564227711139519", + "5/890436893221275095", + "4/7438574361567296424", + "2/7925492256166285842", + "3/4911021823516220193", + "0/947438913772985300", + "7/1819783923526497037", + "9/5366582138882559889", + "2/4781335451428744942", + "5/4607531251409209865", + "9/4288052343690500049", + "4/5196880032365723214", + "8/5156295002195182138", + "8/4115455341945047898", + "2/4500342379178992942", + "6/3077679757647692766", + "0/5567045856309322330", + "0/288370486109484050", + "9/9092568256463310419", + "2/286197236240570072", + "5/4144856307704418605", + "5/2711962630403364715", + "8/8203991097892004088", + "7/360518965281118547", + "4/1514925395959469254", + "5/5486173170327792755", + "1/263286407352862621", + "9/1589086408793540489", + "8/4477991031580266048", + "8/4413453073745708218", + "9/3996562499964606869", + "5/2897285871179135515", + "5/4536628890164178885", + "0/5434764071775478040", + "7/2844645667583617447", + "5/2101537987357086155", + "0/1031427810505879060", + "1/6479090333116559171", + "0/1459459205160396850", + "8/3998809761425747718", + "8/1085678911436141298", + "7/5780075582139287547", + "7/3633630746826254717", + "9/4610497070564574009", + "1/6244056347770385491", + "2/7017347069921951582", + "7/7233407990773010227", + "9/2273076704612497439", + "4/5120849541149636934", + "3/2058653157811018923", + "8/3801591671266351118", + "0/1558924779742579380", + "0/5222918352300161090", + "1/7685736291345516441", + "7/3191319983262881117", + "7/1004854449098419487", + "4/5959261440404678264", + "9/7214936208898121099", + "7/4966336986572319047", + "0/7573206486630856790", + "4/5818055956712534744", + "4/832640011377056114", + "3/5086249136665929873", + "2/2515613415964928362", + "9/6472029846267288039", + "8/8541397720879082928", + "1/4924632861143038551", + "4/4148961757191340114", + "8/6433309720274860798", + "4/6841294625329622154", + "8/7655519722378933038", + "4/4995268633178697674", + "9/6091365634235562569", + "9/9207121887836531739", + "3/7076868229670232823", + "9/4485345209065174849", + "3/7502195069513595193", + "0/3628536396329063340", + "8/1430745652578634808", + "1/771944738245200981", + "6/8103619027995847286", + "6/2608339869906122356", + "9/2111947526911871199", + "1/6794616872164402231", + "3/8341253897191133103", + "3/6654342004598497753", + "9/3311460001507841749", + "3/8889328306314608833", + "4/6693500342745230314", + "1/2353827007892119101", + "4/7818687744359000894", + "6/7759832427182316466", + "5/8381395526261660715", + "2/5094428986219464022", + "6/4441009560416586336", + "6/7731846190313712246", + 
"4/4104458335426488434", + "2/6962510497192536262", + "8/5132714606972472278", + "9/2286214955254649969", + "1/6095560006724677761", + "7/6057334180700051747", + "1/4129303098824033801", + "3/3818502379905271693", + "4/3843200155054922564", + "9/5666818133082771469", + "0/4850516755695240320", + "0/3854037895687150600", + "1/7402903259523183791", + "2/4213225392696186962", + "7/3699350283001042097", + "7/5873811225115650707", + "1/8860791254960733661", + "4/6388609412759201144", + "4/8440218115013757544", + "2/5343270268749733822", + "1/2660545661567939761", + "3/8959540836585605393", + "3/3109165577329658103", + "6/4912829381584752076", + "5/7935124697668455435", + "8/7959955658434228628", + "8/3543306943738287658", + "9/7339537585234758479", + "2/4997619940273737562", + "8/5690284589639572238", + "3/7546722064006707723", + "9/6480863713336581509", + "2/7749575512013409272", + "5/2594656766897487475", + "7/2633906408252400017", + "0/8928535482587791670", + "3/7933750051417673183", + "0/3266405811611944060", + "2/4424141487789832022", + "0/8463337257994434480", + "9/2045697281859302999", + "1/3235958220138042671", + "9/632609457686638819", + "0/3186103325728321880", + "4/8305853267344806044", + "9/8902709332565561829", + "0/6591713808649556440", + "3/8728265745846084653", + "4/3552212544098079374", + "0/38716312618064220", + "1/9493089140174401", + "2/5489060948691966612", + "4/4314176638769213874", + "7/1291635074720847927", + "7/8478946407431234317", + "9/4742451041944598029", + "7/7793939828825439157", + "8/6464167506179560098", + "2/862286702990504222", + "1/3212098387628876311", + "4/5329875324130741994", + "4/2838821058084527534", + "7/2936011985416284717", + "2/158087043096208012", + "3/6367513288924479393", + "7/5683542692334963827", + "3/1914698716154970633", + "2/3633000978717992592", + "7/2691996046269275567", + "1/2729494919026563201", + "9/382189100386668529", + "5/8505084587973643605", + "1/1496974031523688831", + "8/2053537984447789378", + "9/3935319934786822729", + "3/6003247051917616533", + "6/3653916571765917656", + "9/3744534672477546769", + "3/4743173224485814973", + "1/5989502450625304221", + "7/1206033872483429307", + "9/6395189531924536559", + "4/6775238595644289784", + "5/6216809926393625805", + "3/6252501164420523863", + "1/926448783139633501", + "4/4894605842904896714", + "1/7465997452267241161", + "1/6466550338626112971", + "0/8145239982446012400", + "3/2012001680120982693", + "8/7193550962475965408", + "4/8823305295381948444", + "9/1944454509878091139", + "9/3800055909908704419", + "2/2072453306470646362", + "7/8554969899527338987", + "3/6303876453618645263", + "9/5666044900401864289", + "0/3920759878220224740", + "0/2658638009989801420", + "6/5599747489035718326", + "8/507072897784708688", + "3/1574370988387470753", + "6/7364308243576589626", + "2/1337997952239146472", + "1/3800907618544633671", + "4/6023572166732847624", + "6/241455493440226516", + "2/7363400703615049382", + "6/8158329566826974756", + "3/7255878407025838363", + "1/7742752657999866111", + "5/3425897447856487775", + "9/4978795701414523399", + "7/4890897435894363497", + "0/2059438623989414320", + "4/6219287033410447794", + "5/4414132513890653705", + "8/2606616029063974598", + "7/5076337331154749717", + "9/3285851601230784349", + "1/2540058971382283851", + "0/717831117442830240", + "4/2377889848246545834", + "4/160523571497340714", + "1/2701143881512740471", + "7/1976526660909347777", + "8/8508482242321879828", + "3/1806911761443186593", + "6/4895780798664647326", + "1/6547551572577649041", + 
"6/8399106235861791296", + "6/900750794054427966", + "9/741722608914754729", + "0/5214053721419748950", + "1/8549418581988304081", + "3/9033925806681324513", + "0/7736005844567421570", + "5/7624796546513313845", + "8/2814993017888526888", + "6/3146106149707855306", + "9/4320990183207558599", + "9/3878631469088534959", + "4/6375734095001892284", + "4/8802691737569211464", + "1/2576775563010082381", + "6/4130911766934300406", + "0/3473173188849829670", + "7/2926259751685155417", + "5/7866225364419307125", + "3/30806425399249573", + "5/3758152488676158725", + "3/5170007545915290253", + "9/2754404714226223199", + "9/856945445812537029", + "5/6845468359738936835", + "1/8151932786644209021", + "1/4071849336988435771", + "7/4942633818553105087", + "8/6717070080123000828", + "7/1244822227063302997", + "1/4721693241075250441", + "8/4506251590529963688", + "6/7240285367703392376", + "1/3551233905788594401", + "0/3602150352583400360", + "8/5716497419993227768", + "8/1767611141600768778", + "2/8468873606880641312", + "2/4658786108981502102", + "1/2889915446081521221", + "5/939490792208247725", + "9/725122919254004089", + "7/5566508548981456917", + "9/3980742647550980279", + "1/5886972519132295671", + "2/2148359028610094972", + "5/2690878523179408165", + "4/6127610571954913294", + "8/3016990698389864538", + "5/8023040633902079045", + "4/3784807617380451854", + "5/7595221049861259485", + "9/7165535102438943059", + "2/4977460070548850792", + "9/3122564191063171079", + "0/6746765972629657330", + "1/2360867422126063781", + "3/7316512553443830053", + "5/6378994777822690455", + "7/5825403178989243297", + "7/1995608879071424327", + "2/6259465237382403862", + "6/4131427361739967886", + "7/7593542861359167587", + "7/5152351496489571067", + "3/1579004706764058563", + "4/722748758565259364", + "0/2737843533855360520", + "0/4763764133216160860", + "2/6876883837993203062", + "5/219130334806383305", + "7/8864436364234861697", + "6/7820382021477020366", + "1/7667114473979999361", + "1/8729444168338377051", + "0/4231134736762549290", + "4/487673722199765014", + "0/5717170051348244430", + "9/7955994639097014449", + "9/6574013436164343169", + "7/890472312393861217", + "0/965956219368207610", + "5/8008057091915224605", + "5/2617433940491049275", + "1/3309644128515586661", + "6/1071236964245391206", + "7/7640084920286921837", + "9/4609596256673975979", + "3/1713611303078213823", + "8/3052158562973043518", + "9/9160233000751795089", + "0/5715994959568482740", + "0/504324931973920670", + "3/2032616363747049053", + "2/3811179467643340942", + "3/1489927746429287213", + "1/7356772574476607501", + "6/3622285036513910266", + "6/5390410616597310056", + "4/8120471039602742764", + "4/5571408976810192984", + "3/7300421241863173093", + "0/8744275514108622980", + "5/8321053381386865655", + "4/2380162373126932074", + "5/4851664977415615195", + "5/7007861937010902345", + "6/82014185215631776", + "8/4205797659545733948", + "5/8846227157912943525", + "9/6627548726502425839", + "6/2556511087914912656", + "1/5016042186266963001", + "1/4834562294022263451", + "5/6129706385986831855", + "1/1229759664494945981", + "7/5519589229748703627", + "1/866512558730986141", + "4/2549759072059765554", + "5/1937299238029817245", + "6/6636085990682951016", + "1/7876018362540586531", + "6/554149425581897956", + "1/7333155670535245441", + "4/1834725759775046094", + "6/7923272205925076976", + "6/1406626456356005546", + "9/6517210964966486329", + "0/5060277260065032810", + "5/6267323863885777085", + "6/5901472943696481216", + "5/4614097588272231825", + 
"1/2269224592520751771", + "6/7277362401371834466", + "3/6333394867182102253", + "6/7055280165778781216", + "2/7222160953425629862", + "3/641796884715615963", + "6/5908842435843026156", + "9/7835904557795487779", + "0/1682013090724119090", + "0/2102292881932140500", + "2/6115639441090208852", + "3/984212584655441443", + "3/2925983232295463253", + "6/6933791569656185106", + "4/1506756264931307954", + "1/595238247861797411", + "3/4757877242808799803", + "1/3366896327256335261", + "5/5309387392721111965", + "7/5984202995665930227", + "9/4886137675330910049", + "6/1579117838854218136", + "0/2167457125422507480", + "9/6961123930615749409", + "1/130496408418468961", + "5/3088064642985298515", + "3/1975557881579879703", + "0/4663775992715899470", + "6/7386828198745469176", + "1/6426378192705592171", + "1/809594022250528131", + "3/1800447232841270123", + "0/620280262008964480", + "5/648876553405728395", + "6/8761719588449898786", + "2/8283728993664499602", + "3/837694440387998433", + "5/3519827358979131835", + "8/3207400458115992018", + "9/3834829881771501149", + "0/7701542287009585640", + "6/4010164313337948116", + "9/7749442571352080259", + "1/6829560588197088741", + "1/4538590612284718051", + "0/2895280481767493760", + "6/4163305077242275596", + "6/39882451893315076", + "0/4313685998369196840", + "2/8124505627393927802", + "5/7686752410709379875", + "6/4909135823989563726", + "8/7511556035224181588", + "1/3552505186024898361", + "0/7045265417833372440", + "7/9077396089792686967", + "7/2574314011057272167", + "0/8362946598482035900", + "4/6387901398817364724", + "9/5922778757917244319", + "3/4503562569152568183", + "1/2184646705646242691", + "7/8634702908282577017", + "1/7701324510470904151", + "2/7646379167470448462", + "3/63381111249527413", + "6/2682023633440717966", + "1/3106374678377841101", + "2/8717566045527862012", + "1/968955350215980841", + "3/3815992723203505783", + "7/7002928288512570717", + "2/4015249093643561392", + "0/6931774109206710060", + "3/8977173847732819713", + "6/2791598538316852316", + "9/3421710632752758189", + "5/1808473054309891825", + "6/5155600652534049256", + "6/5555620994673468266", + "6/6219424486256037136", + "7/4714031074648211247", + "1/1946521799433149321", + "3/5306361997149740513", + "1/5858108278634830531", + "3/6484725346729263473", + "6/1255917915447998426", + "8/1390922417227109968", + "6/6002586568135554316", + "6/1371872324752136566", + "1/940430863078135671", + "1/5153940746945573521", + "2/9112614655695216182", + "7/5534982556609206697", + "8/4724745781869348778", + "7/9054532099974058097", + "8/8555216507474591858", + "9/5779968271311815799", + "1/267976352872208091", + "5/5209934227616034915", + "6/2842649076459145016", + "7/3938502866467459727", + "6/8431179058184038566", + "8/6937942645537764928", + "2/4276618802626941282", + "2/210052847735807182", + "2/5888275937070829592", + "8/8126349752283172928", + "3/2315997181056358793", + "9/7897630748687824919", + "6/3320959868934455646", + "9/638879870663441649", + "9/7054581559819655459", + "8/905074056810186358", + "9/106617834206412049", + "7/8238648696947468587", + "7/8848848191012837427", + "5/4399578747590529925", + "9/3988089278327346239", + "3/7128322386476475963", + "7/6233863979779266067", + "3/8262214622134530483", + "9/7001580066701711549", + "0/5053537939493519390", + "2/1508555204972377242", + "3/1691941150466816513", + "4/8035481501333610604", + "5/5665946838590466085", + "4/4520802565258108164", + "5/7113308130871604115", + "0/536303363180131160", + "1/1053718519037955931", + 
"3/4956665404749893653", + "1/7253033858367635031", + "8/8883873274861751868", + "7/326609632064251477", + "2/5942767694006351882", + "3/2671801726816697133", + "5/430047758954599295", + "2/2857199654048629242", + "0/4878505357788891150", + "9/2993791952638281779", + "2/113984581738308512", + "1/4512021503177430051", + "4/7289255131926215874", + "5/7551739865854879875", + "6/6111514162810893566", + "0/3316027330639534410", + "5/2369893486626534195", + "4/2628198839629467594", + "7/119151232827945417", + "0/5878968135684076850", + "0/7366460052533091580", + "3/5986298662802155253", + "9/2240377770020712629", + "8/4086811793227207778", + "5/425165392213076065", + "1/1222092808882420031", + "6/6271352936114275716", + "3/3254516276420800773", + "1/6995762015186280451", + "7/5653035247814133697", + "2/5801155583447752902", + "9/1741640461046874669", + "1/658595318785034761", + "8/6221240640038995108", + "0/240773120517557880", + "8/6869441511279073668", + "9/1474133588369447899", + "1/4875228375326526861", + "8/4727633905296470248", + "0/4790081704568370500", + "3/1461363051938795703", + "3/1677430162201157113", + "9/6113046913888347299", + "4/3354814635016136694", + "6/5718953893216821646", + "6/8165276147738709006", + "2/402648680457781162", + "7/9089866213157755597", + "5/4005757247836429745", + "5/6075550936918641935", + "7/1306517510793250037", + "1/2811929640298252191", + "1/3450959251492673401", + "5/4118230638890744645", + "7/2103983204457149487", + "6/2801077619709639036", + "9/6760880204492282079", + "5/1079642278387042455", + "2/530571703733657022", + "3/7475295516870535363", + "2/1662782210195376942", + "0/7716196316752923290", + "0/2586415430842779950", + "0/2303733979792843000", + "9/6915330451458278079", + "6/2283428619473742266", + "9/2165376846605960069", + "7/7562728684706525577", + "4/518913597396045494", + "1/1115979378793808761", + "1/6583251678831599521", + "3/1649483199352785553", + "8/1346154175809670918", + "7/156091862841045137", + "6/9082708495246760646", + "9/4601189726920685759", + "6/946439597855550536", + "9/1749413212546343039", + "5/5903043141553358645", + "0/6249656528549748430", + "0/9083336656975073230", + "1/7538494219521907181", + "1/7564423958290859131", + "3/1048166239328874023", + "4/4817068496006848924", + "5/3585173543592582375", + "9/1527603262565439879", + "6/3350535402021563546", + "6/3378233792227722616", + "8/5432666084543107278", + "0/672425528829057870", + "5/7805378164355570775", + "6/5176892008738090746", + "7/3934305875901031207", + "0/1517021781432004670", + "2/5679884278952881822", + "1/467876739391020351", + "7/1108324869658085127", + "9/164647158538518679", + "4/1567990021299526604", + "3/8722807344296072793", + "8/1803973447359942268", + "3/8477768114641947703", + "4/6108766480381707044", + "2/6584912611613350952", + "6/787014588224482596", + "1/6863109335251963091", + "7/3243957917999101577", + "8/3264716651784818738", + "4/5973115356480461014", + "2/2189407179745141182", + "5/2083804273720774555", + "5/7942799111778464165", + "3/6663676831222881343", + "4/152908931209812034", + "2/5358749276939748122", + "3/6702353458033383553", + "7/8728933585864472337", + "9/5244245350824727179", + "8/2522195078604792898", + "1/8574585983109150341", + "4/6126219640569945994", + "7/7469765638078918017", + "8/8926619600183612098", + "5/5976197835917260045", + "7/3967546654618180777", + "2/6157169189286300642", + "8/7716831136697856018", + "9/3313418988483555279", + "5/1680734546289186825", + "2/1264248206895332882", + "0/7844856194846375700", + 
"9/8610319298605393799", + "3/5100416502338582813", + "3/1559201159676596053", + "4/878636834254571104", + "0/664390785217845070", + "7/5209430350995621687", + "7/1576968564832378437", + "7/7729246458367699487", + "0/2567503970694987750", + "9/1496162338094165579", + "7/6025511837132226257", + "8/596975807886285038", + "0/3658379276245038050", + "4/1173931401109141524", + "7/7898861365768784077", + "7/2464929169441080717", + "2/6275408169799441612", + "9/8017812001254655429", + "8/8012776653689778058", + "2/2595082417684813212", + "0/5948790368274437760", + "8/7175071267176594918", + "1/1230992333702640241", + "5/5676438052907528485", + "6/7008459181599916096", + "8/7438487591760144548", + "5/2257373388414581145", + "4/2609029497501842494", + "5/957934436210881775", + "2/1791447513517704572", + "5/6605858442773846905", + "7/4122484672565554237", + "5/4357203587379910095", + "4/1457612915198682374", + "7/4413508981436256927", + "7/5097434243769810997", + "2/8119019652623178752", + "1/754408713338427831", + "9/6766671337304169519", + "6/351355053459572696", + "0/3314921070943113540", + "1/7443901930183559361", + "3/4910260645204753493", + "9/637974596696798099", + "3/522208060684489983", + "5/1368284020295522005", + "2/3688512539715076622", + "2/6360990702850738382", + "9/5203882388792445409", + "3/1872898159238913163", + "3/4424679446232996723", + "6/9726174788738236", + "2/8141508187800374662", + "7/6444930865315835797", + "1/2760281303137876541", + "4/1778591055407605204", + "3/6334560914448587313", + "1/6625759305504211061", + "0/6335969086512850070", + "7/4693812212409630557", + "5/5557524141635303615", + "1/2022267982343604441", + "8/3958315352638945568", + "0/5441824940427984270", + "6/4640998284684350716", + "7/5804499536389555217", + "5/46894910170888775", + "3/2612541823905582103", + "6/6870911478905495196", + "4/2006922534595088364", + "3/5669049866595403963", + "5/4518917591424024975", + "5/8572597732677447115", + "3/3821519359522885823", + "2/4083530246457195262", + "5/4598754574942756525", + "9/2785166872329277319", + "2/1652140332365972842", + "7/8001527795767513917", + "6/5032054631410393366", + "3/8542931992697243673", + "3/5930316325962238983", + "4/2974378545170413764", + "0/7392969633927339060", + "6/6268388916150634286", + "8/4571052975356585878", + "1/1590785492822321371", + "7/136427201476998957", + "9/3090883363111396179", + "1/2046934564453848521", + "6/2317985355780981436", + "6/1696211949715260356", + "3/7275632761751355193", + "6/4485411316337033116", + "7/7396759061114691907", + "6/8830719869335133526", + "5/1639271834428494125", + "2/3154016548533978842", + "6/7023142637230127796", + "8/8193418077112802448", + "9/7168363981223949899", + "5/2085700345265064865", + "1/5455455020583211251", + "3/4579363324627461453", + "3/6927170614088194163", + "4/4640076122735076344", + "7/585326911250185477", + "1/4437300918502001751", + "5/1514112769403183895", + "1/1954466315994769661", + "7/2385587358967069857", + "4/874786632111657384", + "7/9184366020572776847", + "5/1935617820497549225", + "9/8882854717413315869", + "4/8230650062989714974", + "7/2393228363066015727", + "8/3390669323185946128", + "9/1736188192768195529", + "7/7797709686629252997", + "1/601670025961638511", + "6/5321433572319865946", + "8/7749181459734728618", + "1/383224893846168611", + "0/6620620780011977510", + "7/448008915535927557", + "4/5795553392602678984", + "6/5838352761580527956", + "1/8039943962664312551", + "4/7326785546328314734", + "6/8032847578823233766", + "1/8639715320865523811", + 
"1/1995551907338074831", + "1/4133179172424272321", + "2/1732093664215583462", + "4/2353879730810265154", + "4/4165174791252842524", + "4/4352024335489565034", + "1/7886237444876158931", + "3/1216026490221825383", + "5/3370000588090594865", + "3/54161351892430503", + "3/7344782557202619073", + "5/4544630579919428945", + "5/3588447675299432845", + "5/620977428041422675", + "0/1845202981803098700", + "7/600205285056942397", + "4/1026399454137602734", + "6/8187381323346721636", + "8/3116118276498963218", + "5/1534988718244207995", + "4/2513891215279609074", + "8/3641454895223827718", + "9/8431640610786606259", + "6/3077392948423452176", + "0/5399012092552750060", + "9/2109063729527277389", + "1/4467594084507801471", + "2/949823897221888712", + "3/6202102267734934983", + "9/7527842073743511519", + "0/3919925098503909670", + "5/8043210327216205115", + "8/1842653283934174138", + "0/1135354269388833310", + "2/4824326200103222702", + "1/8749616311384988721", + "1/1269522819588644851", + "8/4493342383962244538", + "0/4402904416840208320", + "2/7086276696764733402", + "7/29533081765335847", + "2/3817731981479809152", + "7/7501510311343495737", + "6/4332399612469804516", + "2/8377970859453653182", + "4/6439316893655830424", + "8/4856957093545468038", + "3/3951530788401416123", + "8/2969716953789704688", + "3/4019683323068446733", + "8/8000809000998704378", + "8/8905306250517450538", + "9/6206759819801090699", + "8/1959117473101048938", + "8/7520118658812416228", + "3/4552246683444103733", + "4/7136088785339641984", + "2/4405982025368774722", + "1/5360086693103195271", + "7/8993387866018998867", + "6/7591532357962851066", + "0/6479475478460736290", + "2/5722823129187572522", + "6/1448137627399635476", + "4/6376513150393262394", + "6/1866340563428877196", + "2/4933398581565446662", + "4/7750669406647227124", + "1/2674591775345392651", + "2/7548872870222180522", + "3/1974284834456564583", + "3/4985462854759152263", + "0/4321559250499143480", + "2/1194569463410626362", + "8/7336405458627884298", + "9/8652386272053452739", + "0/5381633601546830790", + "6/6986288847990194356", + "3/6426502150879816783", + "4/6148327320749482364", + "9/4927262096730304729", + "7/1636107407574317447", + "5/2026523486134053435", + "9/4532709568201554089", + "5/4276496504110118155", + "8/2018829332630501498", + "0/7635031413657095510", + "5/1875855647264516365", + "2/2788801838792348562", + "6/2118314237699492706", + "3/1811392964948231403", + "6/5136992212789411836", + "1/5160463170664996711", + "1/7189606295007269111", + "6/3700101501392800576", + "9/3354171108207276849", + "8/4311805997323140758", + "5/7067609553930623325", + "5/4968597639820744425", + "5/6839839331830586735", + "5/6935549312584578425", + "7/3551482608657491937", + "8/1531581385397286058", + "2/8768681402994554622", + "3/7876624312993178723", + "0/1446024252400909240", + "9/6452665947792663609", + "8/2810887983466890568", + "9/1252953835948818399", + "0/8780943346787222420", + "6/7307572404265408406", + "4/4372607556315731904", + "3/4890297422426932343", + "0/5358399855452256000", + "3/7317656960412906823", + "4/7481921185322054694", + "8/1977770969343439118", + "9/3536673518552324919", + "7/8790211834400992287", + "3/681353271520015263", + "9/4851762689788142999", + "3/382286415843575863", + "1/4838839860991819121", + "2/6336483251238202602", + "6/7201367406879353296", + "4/2234696125985595434", + "5/6863943402561342715", + "2/1247060864022452092", + "9/3357852109567617399", + "2/4862453003268757112", + "8/3570654707925743908", + "5/1356313125574925715", + 
"4/4928690293147508164", + "4/8790036735506479314", + "4/5207718481575510074", + "6/5018563080382109526", + "6/6986155357147700086", + "8/4838126842948129358", + "2/2061963051686654052", + "2/6178757452338375272", + "5/2556208446708616025", + "7/6725200158452144087", + "3/2855115974520683573", + "5/2667448970521023645", + "6/8772445562084793666", + "7/5871901194711550237", + "3/2363804125948038953", + "3/8607764025783609903", + "7/4896705039780830627", + "2/3492634749513498032", + "8/1585424170508932788", + "8/3780914503418980538", + "7/5193466635107142977", + "8/6840433325510261498", + "9/7025394767389572589", + "4/6469452289291317784", + "0/508186040755283610", + "4/6063262177566326494", + "3/4818592442832280693", + "5/5030699323098985115", + "9/6079230505591443749", + "3/2960243623713133013", + "2/1840508207918636352", + "1/3707768113347633771", + "3/8807094627685044913", + "0/2027872515557151440", + "2/654619936389571332", + "6/5854455663340435106", + "9/2422564059235525899", + "4/7594896980748997894", + "3/1040276734267195063", + "2/3074886997869843352", + "6/5655960737218108506", + "2/593269312405780082", + "6/654625047230655106", + "2/4513525546695857362", + "9/1414176011026699909", + "0/5563308683522673570", + "4/906403057901589554", + "1/2790641774747553681", + "1/1230800336477898341", + "3/8530681868426332253", + "3/7550418202683987413", + "8/1991305943873918018", + "0/3445188753257345790", + "1/767279254305441631", + "3/8854192597628925843", + "6/42429505853474466", + "5/2659535677584024885", + "0/777112453947608090", + "8/6035914950651843048", + "5/8494343394699630675", + "2/8152009503943464732", + "4/7529376085391518444", + "5/3586316612402584815", + "8/6322194696472093628", + "4/4726674916309060634", + "0/660175704950200970", + "0/1559125530014622450", + "7/4856745643822702607", + "9/4012482164063195449", + "6/1467643580427806346", + "2/4480977428454133032", + "2/2066675240237639342", + "3/5103442855114150413", + "9/2100022498073276889", + "3/304075888960918803", + "7/7921841888885087787", + "7/8170474997448299027", + "2/3042216413341293692", + "2/4950319200068700502", + "3/5947108855654039223", + "8/3788194291072923678", + "0/5508826160789603870", + "8/8977748777623388938", + "3/5131246365062499933", + "5/6821746299350863405", + "0/4190716126965781520", + "6/1798289413809586666", + "8/7410556474847547858", + "2/1894670422235825672", + "1/2731259848949386721", + "4/8642222294459405754", + "3/5159948761327159373", + "6/2719907109263980656", + "6/5512625861451735776", + "7/2544122692628433827", + "7/3793472776023018657", + "7/1853614806833290787", + "8/9011597989540234268", + "9/5939983684789729629", + "9/6440115343928530259", + "5/1856354650516367955", + "6/6329345465374008296", + "1/5812267062406718701", + "6/3825244108707256086", + "9/8548783494496846989", + "0/1308030872403459710", + "8/3903183045636402398", + "0/3376936358862619770", + "2/4717877532956708232", + "9/1196127004269500409", + "8/7398775244299906378", + "2/4368545358566996132", + "5/592931803813176925", + "1/3304175886794352241", + "1/13220713505035211", + "7/3696960306492083587", + "7/3763028558459685547", + "7/540474236660219487", + "6/2177834594616273846", + "1/3666246485013702201", + "0/1370430444812597790", + "6/807776524774435796", + "6/2352608430028143946", + "0/1363806862831598710", + "0/8822344319338528290", + "1/3344280995586773341", + "9/8768062445452279519", + "8/4006222587193476548", + "4/9158539697408509744", + "6/7249220595088410876", + "0/3801901023809895850", + "7/363838578976305237", + 
"9/859571356837149669", + "0/694908503894739590", + "6/2327209970266557066", + "9/2691856628900071569", + "5/724303957059027305", + "1/8745815653970237081", + "1/939552516132059651", + "2/292893279581412592", + "1/4404281802102057231", + "1/465318402175919361", + "2/1184247335348348692", + "6/7384425892114359996", + "8/2621314764675879658", + "4/8556912837760981044", + "0/1597810835174658540", + "1/2551712776723735291", + "3/1661286501457787513", + "9/8711440297586757989", + "0/5944530833691850320", + "7/7479825741008560747", + "6/8793863054563311136", + "0/3622679873795150790", + "5/6405371962866090785", + "5/4810712245227792815", + "2/4595481744798114512", + "9/6268978886177373319", + "4/6270959030209594304", + "6/7581848414837493326", + "6/8794605404335385606", + "3/1843776643333490803", + "7/8708743380661768717", + "4/178836020120729814", + "4/8469821105984540114", + "6/7886716340731303016", + "3/8073566483527818773", + "9/732659525097616059", + "4/9214443118199315264", + "1/2762701837133721311", + "8/6908553259074133298", + "3/6828368342012444923", + "8/5342702352173758148", + "7/1027810996287678217", + "5/4864853546883004485", + "9/605093680071474369", + "9/7055395986684698569", + "4/6850314667601921744", + "9/172429477586741419", + "8/6840868700651513078", + "1/7881213198094493431", + "6/3012844598263053936", + "7/5771528536771181407", + "8/3508437303919503498", + "8/5316492580828141998", + "0/1553014518862768620", + "3/6327293306673778133", + "2/7177005051602951732", + "3/5755743183369980203", + "9/3201305241326417699", + "2/572682879212580382", + "6/7919569379547134956", + "5/3045955944586853795", + "3/5590644354574892743", + "2/4716458025379930912", + "4/8038283893251722914", + "9/5512954029973423969", + "9/6927060051459516639", + "3/8321747560342611723", + "8/2197912678027047158", + "5/7586987102032621255", + "8/4869045686145996908", + "9/4433444549448109739", + "6/3685293734899128636", + "8/3759529146535252728", + "6/8764490087302036036", + "9/5160306151571635179", + "3/1083942653445881133", + "3/8470551824465407273", + "4/5788803375878961174", + "5/9062467445989057455", + "8/4409989560310217298", + "2/6404472160994339352", + "9/6832717908482133939", + "0/6759381539548729630", + "0/2238094377305052200", + "5/6229409522288202625", + "7/4510147468132558197", + "9/117777098499264349", + "1/1334415744942041991", + "5/4671615810908508955", + "7/5278765065231831707", + "2/5833252606831232972", + "0/4406314807761597730", + "3/2555831198558339623", + "9/1530948816448439309", + "8/6730048591377040618", + "9/8549389165640788179", + "2/8348149238527838632", + "4/9016035577704388064", + "0/222544134601175410", + "1/597507624289805441", + "8/1366297765560382718", + "4/6799092633183654054", + "7/3377108819263729397", + "8/3177092199202204518", + "1/3379232584746217331", + "3/1325522680487951313", + "8/2862733147454051588", + "0/2347715675809234280", + "0/7964537240511049200", + "4/6207668338338151524", + "4/6812674063388595014", + "4/7837685247324065434", + "7/5701147008007600227", + "6/450919967217141216", + "2/5076950223217633542", + "0/2508234660084234440", + "1/8638504913078154741", + "6/5011535764644896316", + "7/6542576065235716417", + "7/8933779209565067657", + "3/4757205849373055133", + "3/918764041032041303", + "1/4128941077358431921", + "9/5746624192744528749", + "8/715135686113669208", + "8/9123280945660842738", + "3/2761488975172641983", + "3/80909814091284453", + "6/3445354537696726686", + "2/6939247859090796182", + "6/622745317296267306", + "9/6074720817880217359", + 
"3/1407703364790149943", + "9/5126903676147802809", + "7/2904921391784086537", + "0/4635225194845560580", + "6/6333746907387543546", + "9/6356630319455941779", + "0/2968652229362921690", + "8/3844464787740593128", + "4/2552064145181687594", + "2/4311611041152127582", + "2/7699202260747681132", + "9/8750087166693811549", + "5/4589607816097186425", + "5/5695319396826737605", + "7/8667877828931826167", + "9/8007095045917878399", + "3/1369914056022121163", + "1/5003222550217900371", + "4/3521021506644922784", + "8/8223456394410798948", + "4/3680662438813906974", + "0/2652213809264719780", + "6/7470496452807335856", + "7/7938062576953575867", + "2/8960934886016346812", + "5/5420385329273215325", + "5/3197905070799932465", + "5/8827012754070393165", + "2/302047413754460742", + "0/2520476785781178690", + "0/7834382985940473150", + "1/8667387814281009961", + "4/8050724852375777084", + "3/4398172501722365653", + "9/6917103880575256109", + "9/7667244585221535089", + "1/5267509254030836641", + "5/8061760342306701665", + "7/3728049595016557407", + "9/6820720662810498239", + "4/4741223103160347494", + "8/7592574172337058998", + "5/4995653884607406245", + "0/7005809016668373520", + "4/1349537507845739624", + "6/6319295738627886996", + "6/1374591901512448146", + "3/3128235810154971173", + "9/8999095179143508509", + "3/6150371710003748393", + "5/3945966690093149935", + "8/6580735308048940088", + "8/791001332265995348", + "2/6007709302915143982", + "1/6977569478698903771", + "5/8379256994416305155", + "4/1546107844448979564", + "3/1824506582224391683", + "7/7737001284091166697", + "8/8872332432302771598", + "7/4968635339374860587", + "5/8833365149755032315", + "2/2345928322548324602", + "0/4367116093252590900", + "1/5953048107225770461", + "6/2581852152121964046", + "1/7896341740657392791", + "1/5429277176836591091", + "0/2862951959508264140", + "5/4749613195386014715", + "5/5599113232672766225", + "7/818752746052238207", + "1/1317327301510216991", + "0/1882103864395358240", + "0/1017021745350671390", + "5/6247354933617509945", + "0/6224608419105700790", + "6/6286500072838580336", + "9/2434995667827788929", + "9/4069996210497248869", + "0/4114204810485978360", + "6/4524241490148960186", + "2/8855572912998233712", + "7/7774077929183155547", + "1/7639142226016334341", + "4/3875081190657779314", + "0/8131136156270507610", + "0/6009414969589134530", + "5/4624099000467670845", + "6/4663571635776195086", + "0/1615005819308461230", + "6/6349078218383390206", + "0/621499897327414110", + "4/7064374131149740644", + "9/6469701384556987269", + "5/761665724789005725", + "8/4938592158349443068", + "9/5192204147341076129", + "8/1651901952458368368", + "2/7330051026858447472", + "0/4977032098320876860", + "1/3655803486599892611", + "5/1425032261294080525", + "3/6611487682540555093", + "4/1391932049475658724", + "9/6009489869481368409", + "7/4026683158242139497", + "7/7501748966177538487", + "4/8221630384582996054", + "5/9204207201610950985", + "0/4976451534922498230", + "7/3433371201270370687", + "7/7815363695164318347", + "8/8984355044722952928", + "6/2680722615791179806", + "0/1466759221101933450", + "4/3341591471313244634", + "7/6729269915263917227", + "6/7231168882304242366", + "7/1923250641175162777", + "0/4481063790707731800", + "1/4765494391713542391", + "6/5321917342997921586", + "9/2408451872914130979", + "6/6990253947321981176", + "9/1076992838481184479", + "3/140974880581282593", + "5/3823533774485655495", + "8/1063343824739223298", + "2/7823123409758017532", + "5/7997871360751914065", + 
"8/5374516558041685688", + "2/470377292550674002", + "8/2307455255952507028", + "5/2678050735041751995", + "1/7959271680021153511", + "1/5966738718842591081", + "8/1772448552849714338", + "5/8597646199424299005", + "5/1371378583840454755", + "7/1892220514589385077", + "8/4906660621567681708", + "9/6490239087202611019", + "9/1734481774004856019", + "8/1293384747902613118", + "6/2151301141010438766", + "2/2584323255590418022", + "1/1466421332583445741", + "3/5188126800366559893", + "6/8699002186614902716", + "6/21826831714383296", + "8/7702092315765819818", + "4/6686080820068702454", + "4/4070878446179482314", + "1/2081883513310721431", + "3/2810876996739457903", + "0/1816348373980183560", + "3/6234351309654358163", + "9/8838358539124580539", + "2/429996508147870862", + "7/4660798831677333657", + "2/7377475007847560952", + "1/4830589482300423391", + "3/8447104037050729553", + "6/3902640679831971206", + "6/5512343605639197056", + "2/4531395430432983332", + "7/8222755343056242897", + "3/5215282993262951523", + "4/1587704944225966684", + "7/6757090271872729717", + "8/833330658022661728", + "2/7034310279574669932", + "3/3137080747118889373", + "1/3102020206516924581", + "1/7457003922721937361", + "7/474416994198207227", + "8/8130745412836701068", + "3/3953028511661461823", + "0/1897593842887229610", + "0/2306978450292449160", + "0/5297094039762368800", + "3/7320271413005562273", + "9/7833737012191752689", + "9/2263411577764979949", + "3/8885468892429666393", + "7/2093364225666675987", + "9/2821465857820308169", + "2/7581210836348575062", + "4/4160101502777068894", + "4/6195684047124932544", + "9/6500634228249814159", + "7/1698741256946276147", + "6/7619159429276828526", + "0/275620893320860410", + "7/5863057186135195827", + "1/6339757183378506081", + "9/123296214083016529", + "4/3509282224356500094", + "4/6470534091364198944", + "8/593971159657296268", + "7/6473542910326088207", + "8/6092078785125223958", + "1/6402837517484255071", + "2/6152540728628574462", + "5/8335464652153150495", + "0/6678151772290067940", + "7/2229542695306025027", + "4/7936565217882089604", + "7/6199356245641388157", + "5/4496039087318301855", + "8/3657790328610256368", + "6/1968631680102482306", + "0/3791246211154956940", + "9/956369037678264359", + "2/2153498611513323092", + "0/3163913378227788910", + "6/3847273436811209566", + "8/317726365959478608", + "9/2369464440056077999", + "9/6511329651952033839", + "9/7766507686428391569", + "7/2004278986405498067", + "0/6631964183349210080", + "8/4032332788072319528", + "9/8954273360331114189", + "6/985984326449194236", + "2/3995065093935185032", + "5/476789394996073735", + "8/81611400208548898", + "3/5256550726657819743", + "8/5300549529267373848", + "1/8219487927260798591", + "6/5393743118638888136", + "9/1681402517911022299", + "7/7667661997931398637", + "1/2174456000063381611", + "5/5590604547442913925", + "1/20984558296654921", + "6/4271005433842188746", + "6/8814540461214173776", + "4/4683670610787754594", + "3/4025750153293699633", + "2/914147414606380572", + "4/5865729498576398634", + "6/6635034257065114726", + "6/7371852594857690886", + "6/6865575105453022086", + "5/6803328414710613605", + "7/1526564023225479287", + "5/8366066829471328965", + "9/9161897743014806519", + "5/3257753482468232255", + "3/4010021021007898943", + "4/4008841001091843764", + "5/2467583950289246235", + "9/1720278419627686979", + "0/2489124146515489690", + "5/4870114750986637435", + "9/5887103827260084899", + "9/6444427245368633619", + "2/6850718103922746902", + "4/4360249910110062324", + 
"1/6493894998064643491", + "9/3247658004670116739", + "8/5228699877257729388", + "1/8051169398293962141", + "5/7630502291956530405", + "3/3271268203063937483", + "7/3642443620398288367", + "4/6129499116745080084", + "9/6739445723505190069", + "1/8558439142916145201", + "5/4674164677548271235", + "9/5595245694239118319", + "7/7997694249713633277", + "5/229652030789126635", + "1/8129970900981194211", + "4/340658149770687474", + "2/3671578795445943252", + "0/9052546920784039950", + "2/1927706355945175652", + "1/7682349838595676821", + "4/1655080296349873884", + "7/3151228324753359507", + "8/1226144768574047248", + "1/8564887055426041931", + "2/9133188172751957182", + "3/1899018861154575993", + "5/5798448140922008595", + "5/6013310442616051095", + "0/1491719411566264480", + "3/859312768382311053", + "8/2672746695514051868", + "5/867843829217113035", + "2/1589834932341387892", + "0/8420882445470266330", + "2/4630330657053452082", + "4/2722779624643507874", + "7/2307213097839446697", + "9/8067510325157091889", + "2/3873529460151158622", + "5/4314691920884324515", + "2/2030298728726997092", + "0/6270696799092075950", + "1/282440679993594731", + "9/8498477674672169079", + "9/5781735549399363329", + "2/2264521080704168162", + "0/5191804369929007450", + "8/7950489900897646258", + "3/3825024049136864473", + "1/1834816131536437651", + "7/3886869359265091287", + "9/3954086578944108849", + "5/6713532193768682045", + "4/8562329986564003774", + "7/7375949354982654677", + "2/5751150692719069372", + "0/4157320968071852770", + "1/1156902119252125111", + "2/7099447577383883732", + "1/7961499680365681961", + "3/482996986693337123", + "1/6372963420671236131", + "9/1371534086984075029", + "9/459379352356272409", + "4/6378241090634625114", + "3/4820305073989208443", + "9/4085850685372060089", + "1/5382244433714325301", + "4/5626535682592523064", + "6/5786699577394726196", + "4/8750902759962388564", + "1/8880451771284798751", + "8/9170088992667355668", + "5/2270884242471071455", + "2/157320037785394032", + "5/470678796273864415", + "8/3799597283482350638", + "7/4974274915383238797", + "5/5373698935132926015", + "8/377486841553210748", + "9/735291709382166699", + "8/2589506481908485978", + "5/7042015294806037725", + "4/1205413566459408494", + "2/7489011808819635032", + "1/5001662295323857701", + "3/7381112376272692163", + "7/5546325830869405807", + "1/1148793480597679961", + "4/9007047962384581134", + "3/3381401591857550103", + "7/5780251823512086537", + "6/5361124813891497756", + "7/2759709861084243267", + "1/8946350928734757961", + "4/3625952001291203484", + "0/8063003960542121020", + "9/1745676135786163249", + "6/5437280065961120686", + "0/8164433132905421740", + "3/7704973639664910863", + "1/1890319493389188691", + "0/301023819188052630", + "3/5674501942980006443", + "8/6291436032088728648", + "9/6574573864696547059", + "2/8529991038017473402", + "0/7742531345113563550", + "3/3506940993491789363", + "8/1279871264770370418", + "9/4863043558312547709", + "6/4249199382464924246", + "4/55246016422049914", + "1/3521114955404685481", + "1/4389968909160434551", + "4/3392798036397316754", + "5/6046379425849650405", + "5/1430895048555521655", + "7/4019596718346047977", + "7/4324193076187289747", + "5/518860287169127175", + "5/1362572755966468585", + "8/1017716286010890988", + "3/1449516775148799433", + "6/3402089656895182196", + "8/2945596131756432728", + "5/1833026716464781115", + "9/6533521489188261699", + "5/8584339143354865725", + "9/2350215742738557569", + "3/2464546053786510023", + "7/6510808506330705947", + 
"0/135557450650410820", + "8/4049944153970271888", + "7/7005191581754761707", + "4/7905733517093819894", + "1/1174948808141461131", + "1/3614770833222110601", + "9/4852777909839866019", + "4/7822916201794807644", + "7/8446184260491346297", + "0/3901283098535947590", + "2/7650794219219239462", + "5/692384020234848315", + "5/2849083974100023855", + "6/6696353975013327516", + "7/4159696367168820317", + "6/6756989005487535896", + "6/4452327144281692676", + "9/5428435366353072239", + "2/7347553663374376722", + "7/6325005767483165357", + "8/4190592548114932278", + "6/1894718421104992956", + "4/6960825562849490764", + "7/7047354958978111927", + "9/7393861748125486259", + "2/964373624003688272", + "3/3968705977490175983", + "7/4410052816285592437", + "6/3908554432380034446", + "5/339643359642895655", + "8/2099389318759748048", + "3/7308727983420959583", + "9/1664616306792400159", + "2/4318492895490716572", + "6/4955930409373434926", + "5/1624057849439889525", + "0/425275497810314700", + "5/982927096979398065", + "5/3799024454922510805", + "2/5092038375623347002", + "4/7473374001041928634", + "9/6605368174150174879", + "9/731891226246911119", + "6/311727813658093946", + "4/2650461087797544844", + "5/5088096696560320575", + "7/7138281978714887427", + "5/8450206288645858595", + "0/1289282855879821520", + "3/1997816719853729773", + "2/3769062481440933942", + "9/4276554065118619079", + "5/4416802282392883205", + "8/5807368689223620618", + "2/4619815427349423792", + "7/6490462980473854597", + "7/7290852097851079757", + "6/7298115236073708996", + "8/9181728945980077618", + "2/2518404641225278252", + "6/5274294446064976956", + "8/8761205978362775348", + "0/1395988618840518040", + "5/8770003043983578215", + "1/1138706896168470911", + "5/2520597946423982475", + "8/4762507846556784258", + "7/2580260280924010857", + "9/4816434013909602719", + "1/2126161250070284911", + "7/1936948070679979177", + "8/6386738740884188698", + "2/5001054282873561112", + "7/1899422642160338057", + "6/9180109282165231586", + "8/4366545887853130838", + "1/5567831571387242381", + "4/1513644561038714974", + "0/593023258276062770", + "4/3114193734050789944", + "9/3274964875961114719", + "5/6002552795707075255", + "9/225702225512201589", + "3/2951183876194906973", + "5/1485634855673401225", + "2/6151963903067616952", + "6/5540536599977902196", + "5/2312978507499910925", + "9/9193656389589483759", + "8/5279007531241349118", + "0/2715706214097865480", + "8/6021468983402439698", + "2/2672943271484762692", + "3/5511450783028534503", + "1/8323812488188594091", + "4/8480189543607908454", + "6/9129384844228822436", + "2/2833667053171616932", + "5/4867589345110436595", + "5/4509375547150959905", + "2/2010507666371756662", + "3/789956380904311773", + "2/3777938291893109322", + "1/2262201684264747971", + "3/73099722059053383", + "7/116551390766659957", + "3/1498080299391959083", + "3/4488819925343164603", + "9/9005273699136873769", + "5/5889790854426331235", + "9/8890303198030004419", + "1/8219159606638557291", + "3/6060953150171867653", + "6/985246214621322536", + "0/4468330755173788110", + "2/7965890699504203532", + "4/25771922540830954", + "8/5063390371627324658", + "7/5859612993637691747", + "0/5203575998142204530", + "3/8876755361284924753", + "0/4085934708470583430", + "9/3600249924638679269", + "5/446988878726045345", + "6/2284701144069430956", + "2/1029456057877656872", + "0/8209741654252285820", + "0/3583054622791744620", + "6/2587463152970642556", + "9/2690471104921673109", + "2/262629806074294522", + "0/4262265738391622890", + 
"1/6219090451102052561", + "4/6851111333736244974", + "6/5596407987647358276", + "8/4020337989988111318", + "0/2755690325740406830", + "7/6741155493311103507", + "6/3432617700816226766", + "1/3444209959445967531", + "4/5377799571353755354", + "0/2774435758019212500", + "3/589978576266576093", + "6/7886251090030066156", + "0/4879987359729999510", + "7/6411624874916198657", + "7/6143391051195254107", + "8/6479454931835653298", + "6/6816368293177389686", + "1/2044031345133894281", + "7/2917840701510878137", + "0/5240671838914549160", + "9/3381530027960982579", + "3/7786609560119418353", + "2/8226091555377232402", + "3/637134198387973933", + "9/3511761410235620499", + "8/1199588004763381198", + "8/1707760147131492118", + "5/8491815553709533945", + "2/4212829917140869472", + "0/2007819380846615810", + "8/397784472120802768", + "8/5589991693873249998", + "3/8863968327252214133", + "2/6947059166350753362", + "2/7992454989454650172", + "7/2349037199553766707", + "9/5603272676444218179", + "0/5203681443154835340", + "2/3908043431880167762", + "4/5617664353522508624", + "1/1041853674271011141", + "2/1803640178218147762", + "7/7588946339824989967", + "7/9164663482494439277", + "9/7893924992013224669", + "1/6768923994487851071", + "1/6680089461718249571", + "2/2305630454190629312", + "3/414241195814877043", + "3/3998395869537088143", + "3/5876263819924843093", + "5/6044347521251892895", + "1/7493408466871392651", + "5/5894921802924197965", + "0/6214995593725117160", + "8/6411658993168189228", + "2/7548293030159864052", + "5/6223439312864253055", + "1/7237617259685988831", + "3/1779760268622139493", + "5/6080506138898379785", + "6/8547562622601267356", + "9/1346846622355491529", + "6/3694348347874932606", + "0/5304539707372814480", + "7/7455617653647067007", + "8/6307444933235002528", + "9/6412954704429879299", + "7/7020673671586169247", + "6/3394441140945117586", + "0/9166814195533406740", + "7/2376664136151863577", + "0/4558791126284353710", + "2/7710318626975760872", + "4/4195354407464970704", + "8/8619505960979958738", + "3/8601983195571599433", + "2/1327271097088794172", + "2/1568987476719995942", + "3/6112638883590440813", + "4/7769902201675501694", + "0/2069296163614658250", + "5/64852045132622385", + "9/7681553356857362279", + "0/4821194456777886400", + "9/8412264407876547229", + "0/33048841234194340", + "8/6119708458983880218", + "0/5488958221414592490", + "2/7473598807397982882", + "1/1565718555005076891", + "5/8380074095465211425", + "2/2812488015917186492", + "8/6095211379638063858", + "1/7773810287406766131", + "5/4182225816907858655", + "5/90120771761696215", + "5/4523496263931617465", + "9/4495598922399792229", + "0/6713835726895453700", + "6/3904266484095917636", + "1/3074604249630464751", + "4/8965265124728465474", + "8/7848722976983791098", + "1/5568052759582744211", + "9/8758366330385484899", + "5/2755659135629467925", + "1/1932060772942879071", + "8/5669019586590698018", + "3/292643614995841973", + "2/3927338852366760712", + "3/1207691013861805043", + "4/3710040273909006714", + "7/83826697841366767", + "6/7053893875834995946", + "5/5329294145616654275", + "4/2255784145559324234", + "5/1531996010920011825", + "5/4113139036258175045", + "0/8472884992962671160", + "5/2174152586543233995", + "8/654353385918019058", + "9/6430821861136711529", + "9/7525609065198432769", + "8/5592626692698085958", + "4/9118259336314341524", + "7/3752548468959819547", + "7/1633956533102500607", + "5/2130899643071046705", + "5/597529654768744185", + "1/3131328431366983771", + "2/1697920198043380302", + 
"5/6274072632780357635", + "9/9086581799830575199", + "4/1751998635389334054", + "6/9042426227603016526", + "7/9042128895093871197", + "4/2899819526027946694", + "3/218005040403873133", + "7/2143808767310865827", + "7/3495096849149341707", + "5/8163234756188890285", + "8/5688349728733765888", + "2/3596930480760905022", + "2/7073692246945493392", + "0/685149046101294740", + "5/5161440608496640785", + "6/4476097035800555346", + "7/6151103974907852977", + "7/5507239128513742447", + "3/1509675832612995713", + "4/6042466420228011854", + "9/5394726370644826809", + "0/7317852699862906020", + "3/7744701774078676103", + "5/2100463417634898805", + "6/8276605251488514916", + "7/4095439911000978827", + "7/684551812439008217", + "0/3941519388352783050", + "9/467834910859614129", + "8/7647434750928751408", + "8/8534056118625768508", + "2/8068603115776471912", + "4/3349166829994199294", + "5/1416484814756091075", + "3/4432162651735963343", + "5/6780165406278549345", + "3/412016789119623153", + "3/4649484666968089693", + "2/5453440707808618372", + "5/8767823553759206745", + "6/6460593563660743426", + "9/107159197363605129", + "2/9138892883072488102", + "5/5665145748607027295", + "4/8658238306774381944", + "4/2849169045180588534", + "7/8134923513882386647", + "9/1484670333578476789", + "3/7339126421327028703", + "0/344886352389488430", + "1/5822291936072997961", + "8/3696864941379154988", + "8/1922441646817376398", + "2/3067114970997452802", + "9/6284781652955582979", + "5/8092930406519672945", + "9/3093771033463504849", + "6/1488091001178428916", + "1/1363009620897505331", + "5/2974642242888295055", + "3/537531113514541613", + "5/3524153911185507145", + "5/437960984049918195", + "5/6675635529123282065", + "2/1976769686691690332", + "3/781427084191255913", + "2/1366619909189127022", + "6/1427770711283377446", + "7/7356188740043311457", + "4/1042278195394219184", + "9/4078027029369220839", + "9/7800408407249314979", + "9/3329250477867917969", + "4/1046789658494952804", + "9/8968271120388272739", + "4/1796131030261794404", + "6/89168195241815716", + "5/210460385689634815", + "5/759107404608270405", + "5/2490118520048524035", + "9/6763549052824421279", + "9/7654635708562850709", + "7/3759582051582066307", + "9/2397236509823550469", + "6/1571153729315444716", + "4/4595133630444475514", + "7/4318412647287407147", + "6/8161679915401412656", + "8/5562286250762272668", + "6/6322418878055903416", + "1/5814053702244186801", + "6/3159016132753271156", + "7/799043638365509917", + "4/7843695037835084644", + "1/7800529072888416891", + "3/2278208686153487343", + "7/342962167228110697", + "1/5708619833618961551", + "3/3968615886162859413", + "9/5835259913109691409", + "3/4828224633843910283", + "3/605306139349143413", + "3/7951958384420395143", + "9/8499476164862013639", + "9/3578281324842690389", + "9/8591369021898539229", + "5/2915293962346913645", + "0/373008098193672690", + "4/5574301519205370824", + "5/4051647757761339555", + "5/1189318216410191025", + "2/148797622024211842", + "8/5180188567359432138", + "0/9013872449060451940", + "0/6901742904912996590", + "1/7296371252904137891", + "0/3090679104236074920", + "2/8547328479074703022", + "9/1820298696775938079", + "2/2164704843891588582", + "7/5672696429050008407", + "6/497773801380415166", + "7/9162051045522528807", + "0/1949901099817722470", + "4/4237153073761085164", + "5/6248852895077709105", + "9/4480219771140812809", + "9/1751813459325218689", + "1/4860165596960510911", + "0/1710372546285420890", + "5/4774675222564012115", + "9/5897740523689751139", + 
"3/3309650901715165803", + "2/8418679681742921362", + "8/3625063164808415778", + "6/6303712106818905926", + "9/4001634799026037069", + "6/8170493178645279746", + "1/8535630064298840571", + "0/3109556741121735040", + "8/8742714472296923848", + "1/5540169466865010271", + "6/6733672951862564496", + "0/6093838176315624640", + "2/5932644430724991692", + "3/3655534719725696643", + "8/5187751485593392268", + "8/9079213096429984038", + "8/1289103287426177128", + "3/5586366413714245953", + "6/5897852058985167196", + "0/7511276852238643170", + "5/8116519396857616215", + "9/4308099257912559459", + "5/6403747211209480965", + "7/7361324012686509357", + "6/7693975855780853656", + "4/5453293767946364084", + "2/2659064432211870332", + "1/8620004109523672041", + "2/6477773639857674732", + "7/7570429144982444097", + "5/4148778170095081285", + "8/4533847059684771758", + "3/5046740926769088023", + "5/2346995220203179545", + "0/7945236497123320080", + "5/667791317771329205", + "0/1470190039956376850", + "0/1911331857931219260", + "5/7818561530426580425", + "5/5326676257177217505", + "6/7244520757667942876", + "3/5315041727230031823", + "7/1262587078349618227", + "5/7462328061631854415", + "5/890820053711627425", + "9/8447133828996388379", + "5/5337142539261364925", + "1/1417312596849664151", + "5/8175878451867223535", + "1/1274139634006670401", + "8/2336590712998739528", + "4/3555705602455458464", + "3/2467525599647775873", + "1/6000746542596007531", + "8/6466931670605054408", + "9/1589703685907990879", + "9/6019430529702607099", + "1/2935463770938550931", + "2/6556395144492779672", + "5/8036787228130299285", + "0/4958885335207550500", + "0/520988452048630510", + "7/6190181945259996517", + "9/8725098017995600159", + "9/4795470930576797489", + "3/4398465440364352073", + "5/5886001424523116195", + "1/7370345618386181561", + "3/8528007167651571043", + "5/5074367596293226965", + "8/2075483669720335088", + "5/2230129084637524025", + "5/6920008806731794225", + "1/7068548199236921251", + "1/7731701511509049451", + "4/4213529107951841794", + "7/4688455293553688987", + "5/8365559781461634015", + "2/4810680866246101072", + "4/392894881385542264", + "2/4566331182280296972", + "2/651844743281046872", + "0/8938511937069759400", + "6/10744637720326666", + "9/4184381557527457329", + "6/1744959275088805896", + "2/1016416921489958162", + "6/6026710825927474516", + "6/7125978625810335916", + "4/1061622931103500424", + "4/6170480660003122104", + "3/8252747518206531743", + "5/2858040366096898295", + "0/6614258433821636320", + "0/3891042114173144140", + "2/2091238985591589522", + "4/5338804783071617064", + "4/5963352016242798364", + "3/5463567720269461373", + "6/1403387275242204196", + "9/523186979494494299", + "5/360303607017486045", + "0/43875995480588670", + "4/7920169077624115284", + "2/2146241334548749482", + "5/9092711865339913795", + "9/3658165668749416399", + "7/2423321845852862277", + "8/2133712895824636198", + "8/2680868708875179508", + "9/8498407038297907379", + "6/2984932497670790766", + "0/3404980665746001490", + "0/3503339509514062420", + "9/5944012791188098569", + "7/6562906353188343827", + "7/3558775747782382207", + "9/7027922861514056549", + "7/7402705026771441967", + "7/7607598072893454047", + "9/5689120133376901809", + "1/3280523630200483021", + "5/9016392239316538335", + "8/8924403548796070158", + "0/4555885783976703230", + "4/5031587324493771954", + "9/3185462752780939889", + "1/3710664550194087831", + "3/8695121673856494473", + "3/3369352411054124373", + "9/1401738047360271369", + "4/4385923456947277744", + 
"5/8114292610951075845", + "8/5735580189128597838", + "6/6505004857248787806", + "9/4284529817169086779", + "7/5469011199918495207", + "2/207505902300160552", + "4/12830497265809044", + "4/8540266539640206894", + "9/720298160147710409", + "0/6272132803571564230", + "3/2943232830161000513", + "5/2998766278220826825", + "4/4558620593782812464", + "5/8486591619467592035", + "2/5482939411046888152", + "1/6241784698099368951", + "0/9172562088318328350", + "5/643611346098972825", + "3/3371486272268275843", + "6/7283769103060777216", + "8/5133112617334712308", + "6/7048972180978616566", + "0/4791819849151253660", + "5/4830008817613427455", + "9/3654010016250364279", + "5/221608166523810535", + "4/1857261798769392764", + "2/5365192667708659942", + "3/2235145095421944693", + "9/6127366497881686269", + "8/7076570068834959678", + "5/6267293399228752365", + "8/4469204748379395308", + "9/5197231827379079039", + "7/5516208387877604867", + "8/8412491714525524738", + "1/5530703637612252551", + "3/640113050279950663", + "4/46755153835406754", + "1/7640373476021823171", + "0/7649044289384916780", + "2/4592051014091224802", + "5/3644572816972930805", + "6/2912496408357913056", + "2/6088675063298978402", + "6/1120461247482241426", + "0/5455086490890227020", + "2/61852856952798292", + "3/6050439572437625103", + "9/2568998483383328319", + "6/2265053280747253236", + "7/3199511126967243017", + "7/7542127578488699337", + "6/8003080930644762516", + "1/145488741750073301", + "2/6325573534929159502", + "8/4715954907112964978", + "8/7757470206329761558", + "9/3154016328351186819", + "9/5150380095657899129", + "5/7129540942991955085", + "0/3030608891457517750", + "9/3286561643969338839", + "0/6680613593330678370", + "2/8920222163079045132", + "0/7308063562971853000", + "7/2198882208758788697", + "4/2205849978101920654", + "5/1915206863605013295", + "0/9191712985003292040", + "5/7287924915556891965", + "9/6773462310258393449", + "4/1138868301081054624", + "0/2473246421774088340", + "4/4618633675481632864", + "2/6735079802213093012", + "2/2557050557653955342", + "4/6794490089473741204", + "6/5052716695001173336", + "8/6344551966740052598", + "3/557264301075243263", + "0/7904610316889843040", + "3/2458033189670351553", + "3/6971904574873301313", + "6/4184466857926566516", + "6/7373685038346079536", + "9/1266416323531831469", + "1/9041816378508088331", + "6/5094534464514562616", + "1/800070265771920491", + "3/1879334081574912473", + "5/686460618876476735", + "2/2478342953142755052", + "3/4189012509709677753", + "5/2292111146084781305", + "5/6717179868275232935", + "5/4928561460504528645", + "0/4341133159010511280", + "7/8648278707246779527", + "8/2889878131525179738", + "9/1761384638556245129", + "2/1632031105392213582", + "7/6316114403202937187", + "0/2446809074413740290", + "9/4153634331381069509", + "4/8263081253484182754", + "7/2472744627019609537", + "8/97404274596464218", + "8/1661499371409158538", + "4/6693494196958487814", + "6/2964402604939468116", + "4/175767267243538204", + "6/5828479961209543596", + "4/8949864710154562914", + "7/7014944017912004267", + "3/6898324111999029403", + "8/9105334590088382148", + "8/3592426003429694608", + "9/353726316612948969", + "9/7032224605398418099", + "8/5832083475937061528", + "2/7125875694227913402", + "7/1849820359237822007", + "1/2454551453314083181", + "9/1265592699132249759", + "1/7704712720632175781", + "3/7641348298964919643", + "6/4863917385678879866", + "2/2412750136743804052", + "7/7202755555797499207", + "0/7611927743894235780", + "8/2363750115077791708", + 
"8/8073217755369504348", + "0/5125149248818876190", + "7/5445260594625275647", + "4/1794613792366943544", + "0/3428672211765944980", + "0/7126555487613237360", + "4/1352673810567437554", + "4/7146174902603061914", + "8/2141639507221299308", + "3/6756846697426707023", + "4/5725887755534211614", + "1/6878356986707439651", + "4/7989122072718486224", + "8/5965742016215454008", + "4/2670772403062244114", + "7/4062943254183654417", + "8/6761880786506740538", + "9/831864008279155099", + "8/3372516693732555218", + "1/3951908746165693381", + "7/6208415232804548317", + "1/7353607180342148071", + "6/2428090947414281566", + "5/4585761544093421015", + "9/4037566544278656929", + "9/4249531783312877739", + "2/639622968775020772", + "3/2395768502412441903", + "5/2632870537209253875", + "1/8022815948228329611", + "1/4138909813767536191", + "6/3402432534578691936", + "5/5490091509738753255", + "2/4353627768647147952", + "7/629436963644692517", + "4/3162101059639200804", + "0/9053862980717579070", + "0/8327971414599492970", + "7/3227002958768452717", + "0/8800669539620296450", + "2/1458591398979287222", + "1/3070826662851447321", + "1/8514688532530381781", + "6/1766294624634476436", + "7/4329043051113539287", + "8/3665019043619948698", + "3/7591264050286500953", + "6/140887941254692236", + "6/5931944579596648336", + "1/2482201632656170801", + "8/2736464792320675368", + "5/6452041658041315345", + "5/7983081621167601915", + "2/8817152641035308632", + "9/7467281040701913009", + "0/4435053719743341910", + "6/5247694181004586396", + "6/6954421178887928816", + "8/7029230781909160288", + "9/1763609401826319479", + "3/5615189115551923983", + "8/8914322533837160148", + "9/5299143154569717629", + "5/8085687289276693775", + "7/1770348186864103777", + "9/7707268896711362139", + "4/8023336547507900284", + "9/8711686161228231589", + "2/4191699909477366082", + "3/341644435446062733", + "3/3423789196171610633", + "2/4299184539624499842", + "4/1276298857244269084", + "9/8012649552505001799", + "6/3338381901115761506", + "1/3385325415286775371", + "1/3388793821063471031", + "2/1828204423718032212", + "3/6630089657687973623", + "5/4857196828898112645", + "7/684311908539665847", + "2/284395211979364962", + "4/1502505991704265184", + "1/7561330732510734791", + "5/3530536181961828365", + "2/2950858215047039252", + "8/3003736060982045598", + "2/1275510636278288782", + "3/2294407606878338493", + "2/8913108787414020132", + "7/2400804745276328937", + "0/1393804178945326420", + "1/2555443093025822871", + "5/6886620716630266855", + "6/242371603155156416", + "2/2868328380021340332", + "8/1505780691064179408", + "6/6438986087468871646", + "5/2882735265280880535", + "9/7664114458719640819", + "6/3766577164998336116", + "7/6392697275435171767", + "4/8827244949570854374", + "7/4082413921410384247", + "5/4084114146645434415", + "0/2579403339481054870", + "6/4635007052442092436", + "1/4423811308866478461", + "2/4700177343135736122", + "8/2781930657843593968", + "5/7451281098734536935", + "9/8708925985134121999", + "5/6662460861700787715", + "7/1871916430829609597", + "8/2254371462248369778", + "0/4453971721652566390", + "6/3366174881694493176", + "6/7833798746118638876", + "9/2101291571762929209", + "9/255621527381431899", + "7/7788623510514067337", + "2/2397237866129775872", + "0/1782961049457770190", + "7/3519944634925339307", + "0/3337424531654674930", + "8/7422276524388616068", + "4/17802834865781484", + "7/5720202509135930407", + "3/3932895462885304883", + "9/4126173016792083069", + "4/3204793096375165354", + "8/4651164966286145588", + 
"8/2164134848037965498", + "3/603818408386122723", + "0/3271961944696653520", + "4/7761482439804825084", + "9/8913801317022540449", + "8/5410727513068251748", + "5/8989174220291042555", + "1/4530523663761343941", + "5/6911901902541332655", + "2/7308359517569840492", + "9/5609803140790616399", + "9/4473221108536689549", + "9/8031247854761548599", + "2/1595910563922892352", + "0/6645719916538105920", + "0/6171884197754545830", + "1/2893402068216653661", + "9/7614063518634531379", + "7/8109104828854878677", + "9/3630542777305053449", + "5/1032110111010302715", + "0/2732454293395088430", + "8/3287297655757058958", + "1/7514566945913796181", + "3/6275441891015832813", + "1/3122953928152630221", + "9/1331817435458198659", + "0/4401826936623794800", + "7/4124815132035218877", + "6/1379457394879106336", + "3/6342263473313804313", + "7/8592515091615325387", + "5/1036264542944812295", + "1/5799924575815671221", + "3/5886478997708374803", + "8/5826517778843332178", + "3/4035071186873307563", + "5/1352425100803662735", + "5/8144010844195173505", + "6/4695993533900089306", + "9/1759574263845061349", + "5/1381342713421967215", + "9/6264982387288645789", + "0/2672128826924130430", + "5/7176687355206891435", + "0/3364006844696992760", + "1/7381036935501099981", + "6/170693047778103976", + "9/5559519483549292979", + "1/5682925015061049581", + "5/486846451211378695", + "5/6980398976781047045", + "7/4623429262443406327", + "9/2168586121674043459", + "0/6869834906500348700", + "7/5495204374094130667", + "0/3251365818533251360", + "1/6816309353214618241", + "5/3504789543861615175", + "9/1036926518430710049", + "1/2075214691411526221", + "2/453214551745758712", + "9/4230175114409822759", + "3/6091154431638602183", + "8/1016754907522903578", + "1/4651164363694271281", + "0/6041489457709942910", + "8/1672666972742319118", + "1/8140520691255695071", + "5/3277733582142040735", + "4/6506616336949694674", + "3/4931316243541780623", + "3/6642385983595143533", + "8/949333577613857398", + "4/5233787077402556494", + "2/4487002549221038932", + "2/8429813725147554432", + "7/2344490880677818587", + "3/5761329275155874023", + "4/689551748081700004", + "1/939117560146450901", + "9/6882961254503700689", + "1/2524439817046651171", + "7/7373975276008538757", + "2/6586122400121045882", + "5/589291343601287195", + "3/3873330367378359043", + "8/7701665632501198618", + "4/2801046905754041904", + "5/576660501817893455", + "9/8200136512394883009", + "9/941922706040220789", + "8/6981521880253268528", + "9/7107774759708629279", + "1/9159670699114413141", + "2/3830381993384589492", + "0/3648171753281176060", + "0/7335329509054943460", + "0/1728707499389531490", + "0/3953598288654932620", + "3/8638923220210673863", + "3/3839354306289646853", + "5/4614073822320313225", + "7/5065110062366560957", + "6/1677764694569818336", + "1/6643538760978664541", + "6/6821142625341429196", + "5/5105391043285855395", + "3/1352260489720519113", + "1/4719442675703871761", + "5/2837921263707168355", + "5/4980358580627318965", + "6/8131081732879233856", + "8/200014170971508018", + "7/2818168321152672187", + "7/7786766806287839657", + "5/6862635803117960775", + "8/3637067613274566718", + "2/3007737362884307572", + "4/445061534495665744", + "5/1391069864507438225", + "3/6705797052278783053", + "6/1170763209544439876", + "7/1505607978791517667", + "6/6597644478840012206", + "1/1878331449195281401", + "3/5285869235358627933", + "7/1857670794208191087", + "9/4255404088271656379", + "2/4160934510205293282", + "2/7503229929586800812", + "2/8867135773505750162", + 
"1/7025992518027326311", + "8/6744939857532507408", + "1/7953716148271220741", + "8/5057537072278330978", + "6/713899761254060826", + "5/2611019882496156805", + "0/3971667644142091930", + "0/1262286455987910700", + "1/4101253448286310771", + "4/5890824493597526694", + "2/1715276049049876972", + "4/6303529144939809914", + "4/7291478949208214384", + "0/1641269018170177840", + "9/238680598699280629", + "5/4320332975603576195", + "8/7766938950347988308", + "0/46021539700593510", + "1/7199138574551031161", + "0/5361608075849130290", + "2/1933689848802463392", + "3/3131707966730538903", + "4/7502834911298610204", + "2/2096527722665540712", + "6/33878087036523256", + "0/77516888041695240", + "6/6198977335593520866", + "3/2388625469125024993", + "1/6056393358183577091", + "0/8657489412078789060", + "4/6183077212588878284", + "0/3192241547822472100", + "9/3063939427631769999", + "5/5042977376127980035", + "0/1896048052644820340", + "5/3258622334831250375", + "7/5399249927705826277", + "4/15888540358010344", + "9/2838747699008065559", + "2/3761774964848349652", + "0/7591624999403875810", + "9/8604692600577690619", + "1/8010530310752472991", + "1/679782888939495741", + "8/7289645613506402988", + "8/5353232485837163028", + "9/4197028962276926739", + "6/34007648270848436", + "0/2277674519836849880", + "5/5809793698340159975", + "0/9050817711818058770", + "8/5676978739956940388", + "5/3686957944631151425", + "6/5528677057420834536", + "1/1163308752979495031", + "3/3542950829566545973", + "5/3625947947262075915", + "1/7558107291350947611", + "8/7769608823420662558", + "3/6027004489203091823", + "7/4880360715505047307", + "6/2110188584910564806", + "8/6917673005171458258", + "1/6589902704031965401", + "4/4825168650098397114", + "9/4710538790972533489", + "5/6491330889944452305", + "2/3687872943960472552", + "2/785806768326406982", + "8/1116960811988302548", + "4/2711172891783950054", + "2/252721615184455762", + "0/4921158581833843320", + "3/8345689227823429413", + "5/3464506261774840275", + "5/7965216394150050475", + "8/1706240133491766838", + "3/4039956347981820983", + "6/4201556204095262406", + "0/3188261933320593080", + "7/6057511193619350007", + "9/5617706316460163659", + "2/2899043180384090432", + "6/109375421639299626", + "9/2360950554895965779", + "5/3926920049635523985", + "1/6259847094717140111", + "2/6633695253230349552", + "6/2727493829770006436", + "2/1070712620932226502", + "2/7309770683832736682", + "0/2538418299181565750", + "2/2068137275379305802", + "2/4567281008169214002", + "4/6472263278455391134", + "7/6990379337909839077", + "7/893070785770268237", + "6/6696268707551060736", + "8/2445715322350650378", + "4/8997206574462348604", + "4/7483421074227362664", + "9/6510085219918071789", + "5/2622324197671021255", + "8/8173431221810723058", + "0/4782787250385982690", + "8/4647170232164009248", + "9/8845287701576954669", + "5/691945367580485915", + "5/5129052387513716225", + "1/1391290508199654631", + "4/3561913042959010474", + "8/125014244079297128", + "5/7009851612740324385", + "8/963474662108511068", + "4/7563919421962081354", + "3/4933842745970534333", + "3/3789716009451258553", + "5/883936023426230985", + "7/708213840093053527", + "4/650702036445054014", + "6/2233468989436982896", + "1/8444874686386040441", + "6/7146286767270395956", + "4/4906250367849098394", + "2/6864464033330317872", + "7/1855619439211468227", + "1/5748884261883922761", + "6/4442502679192873556", + "5/2273170151398093785", + "0/5749908116480830630", + "5/9016641318786687975", + "5/1456861330872025215", + 
"8/500213020530230118", + "4/2804815387604642754", + "9/3498566662579344129", + "7/2547339687459757357", + "1/1394877893753545581", + "1/4104891824923899371", + "9/1831486786354331899", + "5/3357341861672185365", + "9/6550531553048250709", + "0/1750514182211058250", + "7/3852761595401473507", + "6/4507412002926188166", + "8/112321430092994178", + "5/3678334901717172325", + "9/390669678701120779", + "0/6937854530592091310", + "6/693142770932795816", + "1/2289026728592985181", + "7/6177649328059237517", + "7/2958432709450440277", + "2/2788760468142871482", + "9/4286035335769313929", + "6/3735966752656008036", + "4/4633688391265789314", + "1/4356189680527030641", + "3/110219058387835843", + "6/7359029356520224096", + "1/8807459480091780411", + "9/4149752435540640479", + "9/3872292788648926259", + "5/9049725031297712115", + "5/7881997837081513685", + "5/2489796629999020845", + "1/3625692226730710991", + "5/3968147054833037905", + "2/2469817862978112422", + "0/4503385273468617970", + "4/700727952898689994", + "6/1522335350928404686", + "8/6789000481058348358", + "7/4653720836717610037", + "4/7403995698479878754", + "5/1364877347659250455", + "7/2666889453747255017", + "1/8823962987574241581", + "0/8016434577036777700", + "6/8797642502024766376", + "6/4007539300648430376", + "2/1706526975328526772", + "1/2649631954845855661", + "3/617152248567372403", + "6/25268213934488176", + "7/5472983361868416817", + "5/9009092908810082255", + "3/159103332937229293", + "8/4173198896210492408", + "6/5105875374616383776", + "0/682239182678089640", + "5/5233111786253110445", + "7/7786529800596872217", + "0/4507147301381029760", + "2/617826651123151952", + "6/5361643323019515436", + "8/8285215787713840438", + "1/6470990680125824321", + "0/1665852633686692310", + "4/1958101439721141434", + "2/4255629802141595722", + "9/2303727853616949239", + "2/7477455583505008362", + "0/5084940010793953700", + "6/8209871178661902676", + "1/751089991267733661", + "3/3014883380485848453", + "9/4613371291981407259", + "3/2426197968176800013", + "4/1165576984200172944", + "0/2808043571995452300", + "8/2556808403741200968", + "9/3971568617173147729", + "6/4829002735176697986", + "0/2453675128932950270", + "2/5538248279979578662", + "0/1669388077150762170", + "7/7302362184513067237", + "8/6431640038447253158", + "3/6278230202689166613", + "7/6620267545579320647", + "5/1926676501831534275", + "7/9002963456460650467", + "1/3021120456319869041", + "4/1466163837845597044", + "7/7047146201995508827", + "0/3973872337376392490", + "9/7671101323887606279", + "8/7438405456122654338", + "3/506390293016086833", + "4/5791380148931290764", + "5/1982620512500016475", + "9/7189414508288950459", + "4/8642422916954227654", + "2/2497672133466072052", + "2/698399906986647402", + "5/674485718542311245", + "1/1117798752579688721", + "8/6712298168455058268", + "2/8544249562285768412", + "8/6208087852374245048", + "1/7867627244824616321", + "4/2418692134395744914", + "2/4823889503478064612", + "8/3611480211868994988", + "2/7093924030475020242", + "7/5981754257589428337", + "9/3708028448193293559", + "8/6963160035399234288", + "7/6714257684121571207", + "1/8487819506678569031", + "8/4196966537530031788", + "6/1310899728796162586", + "2/1867886958761800052", + "3/2889743159497651913", + "5/4407895463105380685", + "6/7150779850487619186", + "9/5277240247747908929", + "5/1929017825815731625", + "7/7418964594399838757", + "9/6701978562756703899", + "8/4958943751127060188", + "1/6104972079575688391", + "9/692916104038575439", + "6/8746547613605482976", + 
"4/1309136252055954994", + "7/5443531488308032057", + "9/7259642818549441349", + "1/1382855036414655511", + "7/4727478503589516527", + "8/7017230362819558468", + "2/459364084821263792", + "6/1433487141043119496", + "8/3779546080853867618", + "7/3278200331759101167", + "7/4886901751625385067", + "0/4904143627888872490", + "4/6919742514652955284", + "1/1502013558291644441", + "2/6637359894546940232", + "9/7509736811771974999", + "6/3775679874536445976", + "1/1994438583700369751", + "9/3436280237392378209", + "1/4503910505691697491", + "6/1238677996913697246", + "0/2150932568371058200", + "4/4820105348605269644", + "2/3822119815855584332", + "8/4142805237857687258", + "0/8253401080169821670", + "8/500133681525558408", + "1/1441327835679074271", + "8/662679383386430228", + "1/3934112950709188581", + "4/4581811442295991134", + "7/3447317242581549297", + "9/853880778309560229", + "9/4725795111102348899", + "3/6525912831659547523", + "7/6949968450864993507", + "4/1638694150745728154", + "2/7423577807032293392", + "6/1414479946623080426", + "8/2070707084857995078", + "7/4366238731247523767", + "2/356901273747251832", + "7/2480350416686924777", + "0/5304739315308112590", + "6/8602647944325454146", + "5/7202425598091623635", + "4/243951807016792074", + "7/2537022276936884207", + "7/999437592862772047", + "6/3555672592545044546", + "4/481296377309880514", + "2/7503889826640842862", + "3/328544282993749123", + "4/4910958657829905084", + "1/7786820513699823271", + "5/7177983679214505805", + "4/417020987538844134", + "5/3182254912469961275", + "9/6149122100239782879", + "1/2197934626166237261", + "5/2371487913986320325", + "1/4976550448988514431", + "5/9216894998717544655", + "0/4745909504560941260", + "8/8889223484820200278", + "9/700361382925430209", + "4/6840745670228932174", + "9/7901337220203594339", + "4/3740420377354097874", + "6/4790299035675027536", + "9/613924589377679399", + "8/6888919049129174358", + "6/644139548134232826", + "8/1326810739716718848", + "6/4022521510507820546", + "8/2215251443335487338", + "2/7789350238100492832", + "1/8209759506462696861", + "5/56557314996231055", + "9/8793227826750345629", + "3/7356184491694903703", + "1/4032527749643104121", + "6/5129779091165800356", + "0/8509990442839177560", + "2/1352835211892522402", + "9/3858319118170442779", + "3/9221439482251303113", + "5/4935014413338299355", + "3/5704840537841976003", + "4/3416990186926465664", + "9/4551126920496463209", + "4/918209882796929544", + "6/7325593048419508366", + "2/8067091881485824422", + "7/8161967961075360987", + "5/745895330069148035", + "4/3437947161140431644", + "4/1279738585335058094", + "3/5905822359597354003", + "4/3489436372575248734", + "4/6210658408952459514", + "4/2752376730498348014", + "8/9127951668258745108", + "3/7258262270085040463", + "0/6961412803001747180", + "1/3937078821654495111", + "9/7288341497288470849", + "7/4815118556051288177", + "1/4746155579479160561", + "7/7923544783788287267", + "5/5189349779423759915", + "6/4328644808163675356", + "5/7004392815609379855", + "6/4271922055407885056", + "4/8360180124165032324", + "4/6917914460782468874", + "4/7399805415973111444", + "8/5461196806436159188", + "2/1605561148280850122", + "2/5738034821558651092", + "7/4547737423078076187", + "6/4236207734960122946", + "0/5792171940962566090", + "6/4292915581036028926", + "7/2626997649323615557", + "5/585976583606153935", + "2/2428645267432720382", + "2/5525279584172639862", + "0/5570925608414487070", + "7/8477253015918263497", + "5/7710474117164408805", + "6/1184637934537405956", + 
"2/838783918890716872", + "8/3441635248668867828", + "1/1823855990108677051", + "6/4898534335344868156", + "4/8260964323852612154", + "1/500883403291417381", + "5/4522828981943316565", + "9/2286748847811050799", + "1/7206370633471811681", + "5/3664078105561393705", + "8/7976510894427625078", + "6/3709809206262982446", + "6/3365919421896007186", + "6/4405333618624831226", + "5/1361587051227091605", + "4/2120140406539064754", + "9/321517412057371179", + "7/5627504295104483967", + "2/1780278949271782702", + "2/7583950773220993802", + "6/4701432941733625966", + "9/371759257027406039", + "9/6353099797489463619", + "7/1171189077857243897", + "6/7668591593544518576", + "1/2957148092915483501", + "6/5156740384594767906", + "7/6123730508022793407", + "7/7171519162544743937", + "7/3816874439776906317", + "8/8552867373664754168", + "9/1950103829700533519", + "5/4302417564751309485", + "8/3056109163048768038", + "4/5425980541614731274", + "5/2841280483510288335", + "9/5836403542112948349", + "7/3737027146112734327", + "1/5126288386014760591", + "6/7347865195131284266", + "1/4032717750940376201", + "5/6293212183673765", + "3/6682500824777996333", + "1/5054205398774972451", + "6/3890389409394450236", + "6/6621449028015323326", + "9/5154582935462924529", + "9/9158841882068057729", + "5/6011888263258427785", + "1/1110858683856145991", + "5/7848009137664610605", + "0/4127708130893548570", + "5/2321321357005673415", + "5/36946273189658405", + "7/6059559256291381087", + "0/8399935635225963700", + "9/8197388980339815299", + "4/7207318523410233694", + "9/1488000166673486139", + "2/2781761904764300142", + "7/3222885943297687557", + "3/2316826782328178853", + "1/8224237335829714501", + "5/1806650771470867995", + "3/1951973644454057633", + "5/2622022835040918895", + "0/8655124843476003860", + "1/1782058959245454351", + "0/5506960894140983360", + "7/8963160494162676077", + "2/3536896668346387912", + "8/1901537329741662688", + "5/4220720886298965605", + "3/2264377130171892673", + "6/5756198704870363786", + "9/1519726793966903159", + "5/5678810653598639915", + "6/1454409613918716196", + "4/4179595829365468674", + "3/139103950271970213", + "9/4230693403187455299", + "8/7224596892225026628", + "8/720357837506193188", + "7/5346065413299975877", + "5/6194901173118216745", + "5/2402147865021361345", + "8/1477496831154959358", + "7/8483211489842407257", + "6/7510591361023434026", + "4/7685118125136497514", + "0/8555923729036965020", + "0/3537248622558730330", + "2/3842413169872837792", + "3/8339658783469654413", + "7/2381115272373641287", + "5/6714822956519817135", + "4/8884683832579479574", + "5/4553278294241788105", + "2/769346701253311352", + "3/7466990847504840523", + "3/4160108318202521173", + "3/4219178576020770673", + "3/4847593487392262593", + "3/7281810506456900633", + "4/5641146889402690764", + "6/5761499807583008276", + "6/8932588464752947446", + "7/7317967697392846437", + "5/6419504579501372195", + "7/8195345640450098047", + "9/6412054287486710649", + "1/5147468090878536671", + "0/6103766166425636780", + "4/3144469253563476954", + "5/7308823200043893425", + "7/3178988669591061567", + "3/8213736070689105653", + "3/6636547014861986043", + "3/4465455804351597603", + "4/5707061825687431364", + "6/8926643537682873686", + "9/5836803885565940889", + "5/5660686051540892105", + "2/1698222831185098312", + "6/7084673844829237576", + "6/7669653362294896416", + "7/8355304392450051227", + "9/1171850630247176929", + "0/1204308626386023160", + "0/6387882850882990800", + "1/2799452347753289161", + "5/4088735031226475595", + 
"9/8763083926990223729", + "4/1128444566661314384", + "7/6906978590312346257", + "2/3913377050439188122", + "7/6731421795261324027", + "9/5089991752525718639", + "6/4459020658655642726", + "7/4414437805779898947", + "4/8101939262133117154", + "6/1507754803652223686", + "1/5025069241572634651", + "9/2302863373756050629", + "5/8404666619472438215", + "2/3254383844226874422", + "6/6873320368584013366", + "6/2948091384265393276", + "2/574020948268347522", + "2/3042278862144288202", + "3/5140870824013145193", + "2/7845258754454952582", + "2/8813781605373071142", + "7/7838222729174751627", + "0/8852604092238021580", + "2/1218633169406462222", + "5/6599394694679490525", + "5/5412745680254105535", + "3/4223813457938072573", + "2/9179174858169336592", + "1/3517390829765660181", + "2/6302356560318037282", + "9/2055499840426314789", + "4/7628061840645677234", + "9/6701039535024763589", + "2/3361560827793155842", + "1/2261229952532596291", + "8/7358734157523532858", + "9/8178163526649339299", + "7/826193757013505947", + "2/1266755173674339332", + "5/1833029413018140235", + "8/5810589326162133508", + "8/7841408059549894608", + "6/2562246756827533656", + "1/6366088887493135561", + "4/924912632080578854", + "5/5655495759695609835", + "8/4651211504449285538", + "4/7658900917447204394", + "1/6304722382660219691", + "0/7454215266761716400", + "3/1258484940638572833", + "8/9077337514745237338", + "0/7307252954150243650", + "1/7823474291543876121", + "5/6045489614841181255", + "0/7580529979354111650", + "3/4640844885998096733", + "0/8992802201046752000", + "4/443067035690654314", + "0/7861433323248610410", + "4/8527978754759125084", + "8/718753450060520308", + "0/6221023046895765220", + "3/2057738389466596283", + "1/2450699869941490201", + "7/7170045828116774367", + "7/8073306076550697407", + "5/4186047212320027925", + "6/6682922107376439816", + "0/5310496717732324310", + "5/232666507592380795", + "4/3580876891848542654", + "1/4715692627736604301", + "0/6925135000804675580", + "4/4204913831720294914", + "5/7410038460426364005", + "6/2706848768261250596", + "6/59862081692461556", + "1/2498635551624382721", + "3/2961941731769882883", + "7/9198227096064884097", + "9/1297020851904805489", + "1/421094532540189141", + "1/9139112501675852961", + "7/8140230598170173807", + "2/2107828546276668032", + "8/4888917034615883738", + "5/8887781982140871415", + "2/1644735624732669582", + "3/8182074329596332633", + "7/2705371733125437607", + "6/8244084094954362486", + "5/6746228570463343365", + "0/1942949634214846220", + "9/4183316037111150829", + "3/8892300636098168513", + "8/7741315968582457228", + "6/9207223743998733326", + "3/2891620618606739843", + "3/3801044327749703913", + "2/3612243526195493852", + "8/9170710115984063248", + "2/1932169806475749382", + "5/5554820983390299175", + "2/2699540574432895602", + "8/564475400205379208", + "0/5361659915399788560", + "4/137956402252565364", + "3/8935388597621188793", + "8/3953399488509747518", + "8/5021483301534543658", + "4/2965203384571898324", + "1/366578311138349201", + "7/4898917940292948417", + "1/2606023011882248101", + "3/6962734362102743963", + "7/5829527216490934567", + "6/3329373307008538536", + "5/307253928048567535", + "8/2189235777683758228", + "8/4782742690457899218", + "5/3817658283546403585", + "7/7610056072559676417", + "9/8385534318598480459", + "2/3422060565122330572", + "1/3916583639080172001", + "3/433053965835849473", + "3/6195586444779081073", + "8/2353307770438067318", + "3/3390948491544641963", + "9/7801519003308732919", + "3/4768727380105932283", + 
"4/515908221821603724", + "3/5324463923020193123", + "8/4577566970123395598", + "2/994475923461944922", + "3/8477168814437147123", + "3/7243574875622712553", + "9/179010904680559549", + "9/8054815824693315939", + "1/2669884680190212321", + "8/5122961468394650368", + "5/8493512838128713735", + "9/284550144611663739", + "3/4548064921143333313", + "7/7673927614256801247", + "8/8081001157887211408", + "4/8146406346818854044", + "5/5897489474801453945", + "5/5006641165310993785", + "3/5239707242018124973", + "2/5812908027255165172", + "8/6491993481210136428", + "9/4832393639354101899", + "2/2378379463091524792", + "4/2682664085443002114", + "0/1823698582635782310", + "8/6955542992265575478", + "2/2244803574771808752", + "7/6834134988588082197", + "0/6454608773066648410", + "2/3430091746767211132", + "6/7939032245723250576", + "6/5898668897972262796", + "0/1970106291763266910", + "4/1307498976801342084", + "4/5215477834411168034", + "4/3335571488759300404", + "0/3413208970191952220", + "1/7026365034410210381", + "6/2465592345156701756", + "8/9201289911705256298", + "1/842718557527015311", + "1/768919324768057361", + "6/679057094187948306", + "4/7041959720894626314", + "4/7544896513371512324", + "5/4837830456164053915", + "5/1197156250549172985", + "1/5097346020793002141", + "7/8464970102982327147", + "1/8439753341771026251", + "5/1295158831528046675", + "8/3624203384967234928", + "0/6473450152832628680", + "2/2476822817098562222", + "1/7150594894719357191", + "3/5804585806983230143", + "9/4963472722251383079", + "3/2828266897342750063", + "9/1356951550265777529", + "9/2386916687592918009", + "0/9150135379046439420", + "3/8178280687019504383", + "0/802065481456126070", + "4/8996458619853340714", + "7/5385230070911431637", + "5/8567108627715895855", + "5/3536630405604453255", + "3/6971208873296331343", + "9/6113711806364313389", + "2/6063035646940764772", + "4/6962397801887831534", + "4/1295365766487169754", + "4/1776389680341435204", + "5/5184553353543437525", + "1/5484256336369926941", + "8/7040020501493999598", + "8/518544591164718858", + "3/5143060280685644453", + "5/2074878895208021825", + "0/4669908084423878540", + "3/8758954278605905323", + "6/4971642171119466256", + "0/6671456705428313660", + "7/5951201133783656237", + "2/9077180609196175992", + "3/3319313235420590073", + "1/1316856962769065381", + "9/3466451508114073389", + "1/5818651445208846641", + "9/8893237786938847909", + "8/4625165756079676278", + "0/291244906357378860", + "4/3806128401068099214", + "3/5886081538831140633", + "5/7476873940525632395", + "5/8745294407886522055", + "0/4348063833041321880", + "9/2374379414263373949", + "1/9206462664930997561", + "5/2892613434586160725", + "1/6310204011248384401", + "6/4376695424868811466", + "9/6556137179273500919", + "8/7281403657582898198", + "1/7685435301043087031", + "1/6824594261930153441", + "1/868804545455835421", + "5/7833611616462181245", + "7/4122299558086276137", + "6/4451042476919442106", + "6/6964081123275674556", + "6/1797193551406740316", + "5/9008347109693106475", + "0/5335154791579029470", + "8/6442990490810680878", + "3/2504717117135205053", + "0/8579990441125748430", + "7/29852354032784907", + "1/2765312621524325171", + "7/9098042132757264377", + "4/3372929349124117544", + "0/3354667489182854040", + "2/3740303402669683942", + "4/7991021174847367064", + "4/3339900600119240874", + "8/3386521784415408858", + "8/5614480241065670608", + "5/3771106190762672415", + "2/457872302150102492", + "9/8376957774032077449", + "9/5008760186156711419", + "1/4023337404581210901", + 
"0/8636336199608159880", + "5/3854217257506355275", + "5/9199700797858973905", + "6/2703069205902810656", + "6/6525992654126718736", + "7/1699248471743156587", + "6/9096015633197067846", + "1/7957235258276770791", + "8/7184801964362690268", + "6/661693013398942696", + "4/5377088408982230294", + "3/6089031061085053143", + "8/6585757893434104238", + "1/2329300086747337941", + "2/6534563580070628532", + "5/1578111408458547655", + "4/1953923925616065444", + "4/2193677410218863344", + "7/821888738907970467", + "3/8296849551686927123", + "1/2237352552856169001", + "5/5882827413926639565", + "6/1394569083778097496", + "6/2115822121303717686", + "2/6877947397956032302", + "8/7376382623569512898", + "8/4654706480040462638", + "2/378469695762494932", + "0/1237575205310118730", + "7/4466224657171187557", + "1/1345723906856866801", + "7/1249591367438663847", + "6/2459858633554565836", + "2/5050123600001026332", + "2/8175478316754801862", + "3/2099147160093302673", + "5/7262602741470117395", + "7/2305192704446027327", + "7/8386384849472919447", + "8/2534430405998673278", + "4/4246412337225082244", + "3/4794867160112113133", + "0/4668883329338734970", + "4/4326650625516406144", + "1/2014862516968027501", + "2/5602655209095887152", + "3/3340022619121559033", + "1/1167488175366488991", + "7/2767149924011007567", + "2/2476885548016084062", + "2/1193688469146821962", + "2/707130625220206882", + "5/2693331771704007155", + "9/7760808489170101569", + "0/508510521668523170", + "4/1842605408097340554", + "5/4422821996291298355", + "6/909052286420907286", + "5/2198537900601776305", + "9/833664793274364149", + "2/2660363650954301992", + "0/3477999226446069490", + "5/8032096353486220645", + "3/6547056222955912353", + "5/6819867849896987525", + "9/4186733546962262569", + "5/4307275118727868815", + "8/7095474761823281308", + "2/7222776562084588132", + "5/8631084120612983515", + "8/5465911918320280638", + "5/7915071496332830255", + "9/2514532985197063329", + "6/1606148909371871476", + "9/1338243922690731899", + "9/8975168090610255209", + "4/4935965187241321094", + "2/8726551556669143942", + "7/660692537718071127", + "3/1347519616929655733", + "5/3393757543748390415", + "0/2435791289321351400", + "5/6079196133306282135", + "5/7244249269411601105", + "3/8381795627069059153", + "7/8321153069906373897", + "7/1484103586883380477", + "2/6885193412036678332", + "3/2987290227824946943", + "7/6622026306480328647", + "9/1154519867240369999", + "6/4193892742227932766", + "1/1084485091609565291", + "5/420309691542404725", + "8/4962207784002211678", + "8/7699513731384534368", + "3/8466290386245649183", + "8/2574417938257147578", + "8/5250954374967841748", + "5/5364751601077698895", + "0/5096645626300805740", + "9/5048753971200955799", + "5/936557262981337495", + "1/6848771811098796991", + "5/8226923841749003735", + "7/6869978987702320197", + "2/2703765365333054272", + "6/8544705845519015506", + "3/5520689415801143193", + "3/7588289129870078203", + "7/7109045322179879187", + "6/8180243190202277026", + "2/4642201604365203782", + "4/1957734253971061644", + "7/7730475753182935527", + "0/5253400920378194250", + "5/3193316770724281905", + "1/7707554809953942621", + "8/1021175346465024988", + "2/6328984370391386842", + "3/750674783717256823", + "8/508407034689553148", + "7/458328927745656987", + "4/6827552946979712614", + "7/4104994908429929267", + "3/4976804457047767323", + "9/2216215651547790379", + "5/6246294461640755505", + "7/2346529205465560577", + "8/2488891082724828288", + "9/5118463949403770219", + "2/4952838935961765562", + 
"9/7672446488326789799", + "3/6420948867901305443", + "9/4475146077870953529", + "8/9147101315687906518", + "9/746448782768563959", + "1/145725659114903621", + "5/1902169993874119365", + "2/5730363034644772862", + "8/7716414031605486778", + "4/6080851945430092664", + "0/7788283479966602650", + "7/2010498885251420547", + "7/6415970436941813397", + "2/2181708071175501522", + "6/1938275358826866296", + "6/773950284657045746", + "7/8170434775370703657", + "4/1049895711055319984", + "5/4479483670611362535", + "3/1388280846970033823", + "0/6133982799039350770", + "3/1777444296321272133", + "3/4807129838418324733", + "6/3591413415216083676", + "6/8719564895155468686", + "2/2058791483291574302", + "6/2066575287576899836", + "9/1824419316267727569", + "0/1678928977962611590", + "1/9047019575018736111", + "8/1360799490699591418", + "4/8895120517118930164", + "4/6316645998948470434", + "0/4089658609127385490", + "6/1075299479700093996", + "4/4848287659060022944", + "0/1875967301299573490", + "8/459007855788249828", + "2/4802428519434131022", + "8/1357961712977444488", + "8/9050026826592292878", + "8/4702538119255336598", + "8/923601331113087438", + "5/6193186437518230755", + "7/3930750181567355187", + "2/2596688598359026732", + "0/3960375644736155140", + "3/8097783297422092533", + "5/6419369785546160035", + "4/1870019208289738734", + "6/750028696459758666", + "9/3687543789074113029", + "5/8516461219449026805", + "2/6435575656720534922", + "7/5783905767538357737", + "1/6949387271484764791", + "2/8880156118342875802", + "0/3854399611640760780", + "7/2420392349025520437", + "7/7954517385398570697", + "4/2366837415828364694", + "7/1528811595949746327", + "8/8058314862456877568", + "6/6061221994007754956", + "3/3151577116591927253", + "4/163057664431396564", + "8/772576085689818618", + "9/5733466919185387939", + "1/1470645320634805751", + "8/6841032592420201498", + "9/7819325850027263909", + "4/3480692677615746314", + "7/2541839886650256977", + "8/344511030690802058", + "5/2037412370913896575", + "5/8173251886585432645", + "2/2231721480495135272", + "3/1105448705871304793", + "8/8339549908360843948", + "1/6425226991514907941", + "1/8864752956395701931", + "0/5274355457193658760", + "5/1270849047285000655", + "5/8750276723719137885", + "2/6374088771219136692", + "1/9153852317965904581", + "5/4024628882430479495", + "4/7694539601874692404", + "1/4766645776652438681", + "4/6289146358896397464", + "4/5651581114885857194", + "1/6899390556330075631", + "3/449581594761350973", + "4/7279202903900497514", + "6/1434478351037989816", + "8/1139245593692146098", + "8/8655155168860585468", + "5/8735922569712229395", + "4/262243485062015844", + "9/4701942209544288149", + "8/2742321764367843608", + "7/6250676765746406547", + "3/8957202597900553643", + "4/2326297153229981524", + "0/8084875826136800310", + "8/4654147628845351378", + "4/1720651895468610074", + "9/6194212594144523109", + "6/1211428940362862906", + "8/2725259867053063388", + "4/6651976130524224974", + "7/2875574159722699637", + "7/6085547626131616367", + "6/4592969624485808006", + "8/1646396092298308428", + "6/6768863934045021446", + "2/8183574163637716952", + "4/2666090521927577374", + "6/4116864908259275696", + "6/8606851683125551746", + "7/4335149419828248437", + "1/617047807225284841", + "1/6159490911592357381", + "6/1838313086045286666", + "4/7967611415844193024", + "7/3291539818879534117", + "7/8361149179115464327", + "1/5166969425306746661", + "4/1359795604149854484", + "2/1469109509926212352", + "7/5955130656301359107", + "9/2322438278279145849", + 
"9/3596135964007665709", + "6/6288106105833109396", + "8/2012993068473991228", + "8/4244031877277870938", + "1/8363458216134121601", + "0/5260772532742320860", + "7/4882993236068094817", + "8/199536918569402188", + "7/1237283141419043777", + "5/351669680155224305", + "8/7057591798548590728", + "6/6333206539532878256", + "8/4046876169699421738", + "5/7355267507276936085", + "5/4185422473663177035", + "7/9025978660075183817", + "8/7437750021110899298", + "5/6703175165777610935", + "6/7858534422222469266", + "3/8616304817685763233", + "9/5045521068953168749", + "8/5701348480887470818", + "1/2238856788173553201", + "1/2290091540278737091", + "9/6785985563756689109", + "3/1725187267581702993", + "2/2494905934533282342", + "0/670119082639015540", + "7/5319411785513282797", + "9/9042052786909263889", + "4/1482250606042009014", + "1/5254957221394100321", + "9/4518803046037756169", + "0/4828831661202392800", + "4/6937302608742960734", + "9/4552966265198168649", + "8/8635714222914570888", + "9/7392474239889281979", + "7/7420377494869812317", + "9/5087078650758755439", + "0/3036069849554776270", + "2/5829433524669364732", + "0/7226930488959340220", + "7/8217561518298576047", + "9/8811421793700294459", + "3/7973416455626790413", + "0/3016030122582984150", + "1/2466832013530520951", + "6/2263602286248079036", + "6/7342201451517479646", + "1/6844008361946595361", + "2/1338801038429127472", + "3/1065487795967695213", + "0/7117014808644114040", + "1/9103979785208915011", + "2/4280292229247255312", + "9/2214155684794187589", + "4/4774814859001674934", + "6/5815747076983473816", + "8/3059156907862941648", + "8/5258366722961919058", + "8/3674012883267743338", + "2/3703708443438255072", + "6/473345293169779676", + "1/4157543400861728711", + "2/808648667813071282", + "6/5768857300739781156", + "9/733027955218094879", + "5/8113894263525122315", + "6/6725507440506228436", + "0/1121781858642647090", + "0/5266023120554637260", + "9/6267444265924702959", + "6/4926442228997475166", + "2/15504596758204702", + "9/6671614023999686559", + "3/5103592228707312313", + "1/1729823700308025311", + "4/153124832013860744", + "9/685705968081464799", + "7/8287972773811322727", + "4/8212750148332098914", + "6/280050584463676846", + "6/3475371557909560956", + "2/3902295136748791102", + "0/829555908025775110", + "6/6796084189987263166", + "9/2824045810243983269", + "2/197737947464982862", + "4/1850828603297939764", + "9/1548348318403230369", + "6/2926759784878283666", + "5/793948168127292755", + "7/4401491402418227867", + "1/1488959927267435271", + "7/2745635656176229587", + "0/4171344973810277540", + "8/5030476052653147988", + "5/1245840608576904955", + "8/4350470469862007488", + "0/2630683630726711980", + "2/3950414239048342012", + "7/8252201658297697137", + "6/8900690735178904096", + "9/1936855123906873439", + "1/8720550668668745521", + "5/8424866659979349015", + "8/2095154563878043878", + "9/6170097154467896929", + "0/4444114264114368890", + "2/2762779297260194322", + "3/2664500543553668343", + "5/2260237456445339295", + "5/7411584371823448995", + "1/8884420168011483831", + "8/6629096182178090708", + "5/6217492724547130175", + "4/1437513874880828384", + "4/1481553325536802484", + "9/3999087338132655069", + "6/7077457179039648086", + "0/5412180471813092590", + "4/1367469005064537634", + "3/1876191956137745193", + "4/3683125818411978604", + "8/5164307594248855538", + "1/4094843865985168161", + "2/4548175043937280522", + "1/8561766676197378971", + "8/6971171647087923138", + "0/7640996416344278520", + "5/8050727405164315235", + 
"7/5161288922716646227", + "6/5479694836299545406", + "2/6170776114896768792", + "0/7569896055100859620", + "2/5835894938248970422", + "8/225275728816734828", + "2/6341340029354008602", + "1/4304276602398713151", + "8/6915149458914697648", + "3/6027843880774122853", + "6/4760699826622104346", + "9/1119974759452375419", + "1/3672973998338587381", + "2/6213580497061254232", + "9/4318735932402329499", + "8/7820382991539224698", + "8/192735332970899168", + "3/4633554672541925293", + "5/6469353235575173195", + "1/4291452524078531911", + "1/7242121163606281221", + "1/7478235038416289591", + "1/7408253555848629701", + "0/3101683195054421640", + "0/9116221112082313570", + "7/5578800341934720877", + "2/4276529726596124622", + "0/8322355542207328910", + "4/8823478936807925074", + "6/1316469771857687736", + "0/9206822890797766450", + "1/6351923201278401701", + "1/6986514431164740211", + "2/3384386998584881692", + "6/5608412652141101596", + "9/1115378098487797119", + "2/6079258846876013002", + "0/4425303448078495670", + "3/2634735906261091363", + "5/1050326709584698115", + "0/7496820036356810100", + "5/4409099180279061875", + "8/4662307800951792628", + "9/8141146076908198289", + "0/3311279190067488310", + "7/8340603456956594057", + "6/6699937455128503126", + "7/2425144324935212807", + "2/7949664354539262822", + "4/9001649085681893884", + "9/116017486125927549", + "8/5178821756743605828", + "5/8947047521176922535", + "4/8856044865177459044", + "5/2985295137272701365", + "8/6316663553683421938", + "1/5828638768289765721", + "8/536481659629638098", + "9/2191890985424372609", + "1/2285085036502704561", + "1/8096636374428877351", + "7/8897876899518864647", + "9/8957003571325873989", + "7/6768811104704627797", + "7/879733534380679497", + "4/5469599571420815204", + "5/6147014317871929145", + "3/1450841627953374253", + "7/1764981851681763317", + "7/6470524732227744167", + "1/2724436324832285031", + "0/3590506110366233050", + "1/8171777439898545661", + "3/4728961955850264373", + "0/3700946007225945100", + "8/5201648749988125618", + "9/4343999714985716909", + "2/4456112358993066812", + "2/3750449084033035282", + "4/5606449916753595174", + "9/6523528233222660789", + "2/676424401513733582", + "0/3692697510310972960", + "8/1010164970508114068", + "2/8547681707374511432", + "0/7186714540355579500", + "4/8781439621958139124", + "1/5255890132751733521", + "7/7011649075079917317", + "1/6011954728924672261", + "9/3926330630156238389", + "1/7275037842047518381", + "3/8158112097198995243", + "3/4898160243500542813", + "8/7591501581530915648", + "3/2467870727196350003", + "4/5199126144262596784", + "6/862093770323452946", + "0/5559540556040158120", + "1/3428569214058220471", + "5/1544084768348700335", + "8/4373018240817664968", + "9/5738830962273183079", + "9/3079414634069463889", + "5/2849519965405809515", + "5/100757673924144965", + "1/388568139399896921", + "2/1931046980233671822", + "9/6506384108849315199", + "1/5918763884300868531", + "7/5889719155839068747", + "6/570741583478944376", + "6/7969126021929741796", + "0/8247661979959778750", + "2/1965558464011950702", + "7/8908162137462845137", + "6/1227925104914215646", + "8/3428805080881140908", + "4/5472908469123373324", + "9/5759453746565143149", + "5/4854044433283425235", + "5/6009586923078342105", + "6/2873425457065687886", + "4/3859605156773608734", + "9/2352337387517798169", + "6/9190283105924502946", + "7/6072668541334716097", + "8/3002232654989868638", + "0/620994320709587260", + "6/6376492816159404126", + "9/7892385330936620229", + "9/4885186360077738359", + 
"5/199532219632004865", + "2/6836845524042820872", + "9/2614395055537425529", + "6/6240399567662669476", + "9/6958862413958157239", + "9/9077928203297879159", + "4/3351383747652567514", + "3/42590255319964973", + "2/4615447451218854292", + "6/82397748315020266", + "9/1963557260536905449", + "6/6844432738121360946", + "9/2403325891759122679", + "9/4155349681049851239", + "9/5269826979999108359", + "1/6402374552112509851", + "7/6186867876243081687", + "1/8026618019094349771", + "1/9168839217746999081", + "7/3437676569737709757", + "1/4534369741528001881", + "1/7028968776990452111", + "5/1110407157561072225", + "4/5757695720456163634", + "6/6805761552133397826", + "7/1879560145283243237", + "5/9051793373542728185", + "7/2176382610451430317", + "0/6743283466012586870", + "5/7255124018008749505", + "5/2200697853073864545", + "0/6856454735337972290", + "8/2428792447156783748", + "1/2333704213346265971", + "9/8007165153322016589", + "8/8096127840625049968", + "3/7793166869127228933", + "6/3019032206720320136", + "0/7775487922942112530", + "6/4332823773099052766", + "3/2478423076377888783", + "9/6122598442725719139", + "8/2326559753946225638", + "1/2916283856073862521", + "8/3375410804234871198", + "3/6731544005085233813", + "8/4354684906503592898", + "3/340275887393271303", + "6/9007376986550208966", + "8/2089189264268551388", + "4/8938618816646420884", + "7/4621436628439392957", + "1/4112746265717640561", + "6/5989289570069493686", + "9/1811410429344529219", + "2/6267472578381549682", + "4/4771849138405493154", + "0/2258996553701833420", + "3/8385999960920044543", + "7/3669222752640322207", + "3/157858796425757033", + "4/7791104668696169014", + "3/4945477734240634483", + "8/6966720435411630748", + "1/4227592742100161971", + "5/8315368648658741815", + "3/4943499651091960423", + "9/723310153751614849", + "7/7163128393510521487", + "1/5081221605969189781", + "3/7745411840772119133", + "4/6318274117798360394", + "1/1091836380449953251", + "2/5399738367594727102", + "2/7587518925221192862", + "0/4484496397092209800", + "6/8896907062433281346", + "8/7482263378841828178", + "4/7345066945839319054", + "6/4551778916214753376", + "7/7278973235510898417", + "5/6792369265232745475", + "8/5864810961784525938", + "9/1730758009803655749", + "8/949460287964045308", + "7/5614924935276042367", + "7/8722608355994960777", + "8/7949392143680085918", + "3/7617993776727431863", + "1/1115769747621879241", + "1/6099896598579578211", + "4/4094489117718266554", + "2/3117162008846844982", + "9/4905039529788662259", + "3/5693575368601008753", + "5/2667796186180045845", + "6/2007857275381057756", + "4/3138704850234273314", + "9/2640814396275687599", + "0/2200047488261504810", + "8/1178033198305632418", + "4/1924056675340464694", + "5/8407692057009371085", + "6/6614891890347686806", + "4/8067759664895399224", + "4/5697606268980424114", + "2/7766577925281039142", + "9/8747375432329205919", + "0/7187692590181794870", + "3/4882471568671269023", + "5/1256423794048308195", + "7/7264597513901096567", + "6/6324379653819689036", + "1/3434622083141981891", + "2/1893630260382368312", + "2/8329994320024178642", + "4/5423369656172284754", + "6/9108190571980993476", + "3/8504886642341670343", + "2/1251381903099536212", + "8/4448925825425980268", + "5/3126770821879845525", + "9/1475931383931714879", + "6/1242910270656724866", + "3/9158462873877627433", + "5/1131385103741323195", + "8/117450483238350918", + "7/782853793933106047", + "8/562734480447501468", + "1/4717160465174880341", + "4/7575226968292712034", + "3/5540146588867710143", + 
"6/6891219621328883506", + "4/6793359831420938854", + "5/5940334157936316385", + "9/2327877117095801269", + "6/6745512609388674576", + "8/2684570953025110388", + "8/6825255221697002218", + "5/2692771187038708175", + "0/2686648446743840950", + "1/8237465771239986691", + "4/5935878506303958564", + "2/3572092017725981412", + "3/3761873271851746773", + "7/579275009781005097", + "2/8886539608329094002", + "4/1147644776127341834", + "2/5669673979520114392", + "0/8169546554462796720", + "1/1004923204540354271", + "3/8152518989759762353", + "5/7674953492479852325", + "8/1522824285689184808", + "2/7185547084027344412", + "3/6660199431343688633", + "7/3632232380688737927", + "1/4485246792197913741", + "2/2817868232393558092", + "2/9202718670589160852", + "5/6035071975896458245", + "6/1561265850291947906", + "2/4523650756064879772", + "2/817193750435006422", + "9/8783235118480392209", + "5/5053796267521897325", + "3/7071339235194606663", + "2/4756461375945957302", + "7/9105993487284200357", + "5/6225510832874111705", + "4/662119430240220134", + "2/5954800965537759402", + "8/5297490946425243588", + "2/7337606834557334792", + "4/2315136494910371244", + "2/56092664529491892", + "1/6980020556862390631", + "8/8181112143572892128", + "6/6647325587438978766", + "5/5937793192137169475", + "3/2241201055298170623", + "6/198656243734135006", + "5/371376302521563495", + "4/5787737517460082104", + "1/8080183241024179101", + "7/5478344968181588917", + "5/3037318618753251395", + "8/959607488962680788", + "2/2612640450310948332", + "2/4918514641727903092", + "0/7472158115440181920", + "7/3774413198004757757", + "7/5850067906823366797", + "3/3038200027433502993", + "8/1606177666132689598", + "7/8836425410264379487", + "3/6471552871136847913", + "2/5090829931256376422", + "0/6449202507101901550", + "4/2850265203210483144", + "6/1140262708212714286", + "3/7951297484223540723", + "6/2071879600195190356", + "2/6772399800383382862", + "1/2895155425520947971", + "0/4250111571597570830", + "1/6496264550199275411", + "3/8658676203545783213", + "5/8552347860415963715", + "4/3655468612600055474", + "8/2516192218304772188", + "1/2504214647955224731", + "7/5854108069794625067", + "8/323918320243422518", + "8/6561336928818358718", + "6/4803867649293376176", + "5/40303746515574645", + "9/6230995262742952009", + "8/7694022789167241338", + "5/1506378956623965565", + "4/5081600001797832684", + "3/4665808909041721863", + "1/6486433105600751441", + "4/9190200764783861544", + "4/6740191081587264014", + "7/7138744985808827077", + "8/3127037064847280788", + "7/7093613152253468347", + "2/3999869022383286702", + "8/1494822555572441098", + "8/3389063967568750908", + "2/7768897627191864892", + "7/5718326538184083497", + "0/2102326503402213750", + "7/6485318392400424227", + "7/3947991631103385387", + "3/7647361726437852893", + "0/5033935592668089070", + "2/8832464031080848282", + "1/1893231049178374301", + "6/2261505926203110716", + "4/8498589313915246114", + "8/4606818330863612628", + "0/8912744656063950900", + "2/4363313336577191702", + "7/7549031683730271867", + "1/5946347889059137861", + "4/6644130510462203274", + "9/3589559925463233129", + "9/7516460402799900619", + "8/6655828653583331278", + "9/7804431522640603499", + "7/8055313264896896707", + "5/6462588948507580975", + "8/639116532768323528", + "9/5489585137277426299", + "0/8420197695032137410", + "8/8572271450590463728", + "1/6415344083713197721", + "1/4999430598432360981", + "9/3836257137992775819", + "4/3987640659823808694", + "0/2613279201202494740", + "8/5326815663636169378", + 
"9/2356392786899127789", + "2/8021249752608972102", + "6/7741537612597725566", + "7/8597117293706619517", + "1/3841768816536593591", + "6/1039427964708263906", + "8/3279993666292852188", + "0/2990096815206295500", + "9/6026819817247288949", + "3/6328967473283406923", + "9/6970209110421482299", + "1/8113702739376202201", + "2/4303009138630675952", + "1/6775388655547168091", + "0/3941226150071197300", + "5/2163058967642495225", + "4/1756412906847705434", + "7/5477651714106780337", + "6/2787298327916993306", + "2/3584259713792915132", + "9/796365033248191499", + "1/619179691751808401", + "5/8145331942239552835", + "5/2219521776142849355", + "8/1024422226486982498", + "7/617698273435044707", + "1/4629773375945291691", + "2/2171778490981022582", + "7/1722159858666067147", + "5/8422811613549810895", + "9/9004584026960729049", + "3/9031374932177534443", + "6/2174479970970095776", + "9/3482116851746459579", + "6/8919373056477977026", + "5/1867173035943809745", + "2/2421337897333296992", + "4/9019984609091856104", + "5/9130607561162966885", + "7/8672279741106636407", + "3/4149422984149986433", + "8/3207585563428019288", + "6/8701095879618518876", + "9/2048868671807649909", + "4/3527654599295990674", + "6/7809552033643928686", + "6/4918676849154205486", + "7/5594416508214365127", + "4/6918533285950808734", + "1/1897370026260194541", + "5/4670092589235638865", + "4/5505867222307235734", + "4/8874485777266582654", + "1/1973442971947694311", + "4/4404716863128361944", + "2/7497494977171689882", + "2/3371613557403409272", + "4/3657543033458435934", + "4/7071401529780218484", + "5/7523531168794599875", + "2/998659632641323222", + "5/857503026630032155", + "0/6390146497242866370", + "3/253527076065677013", + "4/1174382270511447454", + "5/3780402224657459085", + "0/7881531798894924410", + "0/1391755575005024160", + "0/8680200687059971760", + "6/1796630085369687526", + "7/6982170113221956087", + "0/8927342583570168680", + "6/6553509963235867926", + "0/4830594696653672490", + "1/6437805614763322251", + "7/967171104010729127", + "4/3332141018485864234", + "8/2805384997438054148", + "1/4903761066018256881", + "7/9020115637089556117", + "8/936738755060014468", + "6/6591134403994114256", + "7/377071434365134617", + "3/3007644601428214023", + "5/3253749808490466255", + "6/1609423566635633906", + "2/7971929769523463612", + "5/6587150080161605035", + "9/4299043226370805949", + "9/2874068874974428599", + "8/1017183490860990678", + "9/5394399568503288069", + "0/6071811945676480190", + "3/2151464059920264983", + "4/7689457867860873324", + "6/2403931910650553566", + "3/4750948401968006753", + "0/7556769740489959430", + "2/1308745360919045732", + "0/6534084490972763740", + "3/6212357532929727983", + "6/6330891695893314096", + "3/2412011419870194283", + "3/1574461740376798013", + "6/7027973117472935206", + "8/5194201577649352808", + "6/8710056029858060766", + "2/1930811647810428862", + "4/5369832494458033364", + "6/7358924166989685446", + "5/6689824559897953255", + "1/5853862476901312331", + "6/4066838911086423396", + "4/5367509602093070954", + "8/4018616569051336518", + "5/3180968043253260785", + "2/5334148182589045222", + "4/5785700069168527344", + "0/5982820027530616870", + "5/7440607042439536205", + "0/4121398226677993070", + "9/3614337930854405579", + "7/8204191153897033937", + "2/2608418729693350052", + "3/2758537742234432853", + "1/5133685981891252441", + "5/3086148613954416585", + "9/612518061007481089", + "4/6208224055638750224", + "2/835088289985553372", + "3/6591630466089427693", + "1/1545127612410415661", + 
"5/350163663064200445", + "1/940040357600638101", + "1/5865086699688490811", + "7/4650072623861625327", + "5/402941082793586535", + "1/1058315739824695591", + "3/1963530065477614143", + "8/7003115368036222378", + "6/9192720680039779406", + "5/2571449078961999995", + "1/3085136864384415221", + "4/6023903481246448204", + "5/166563508434482075", + "8/7453229097610099678", + "7/1379109529896518287", + "9/3955732223091731529", + "8/4651908154524962748", + "1/4701150385188685551", + "1/5684535633769488081", + "5/2766606107738603835", + "2/5229998744049876762", + "2/8576283354066619342", + "8/5090319596306983318", + "7/6138413638237989397", + "9/7835487710745946639", + "4/8369554683334999354", + "8/1184181077667647678", + "3/1442290922395124763", + "7/8531600721206853297", + "1/8196429873185562591", + "7/3854371172595276927", + "1/6608679800900581791", + "4/1542457042457108194", + "6/3969608381562128036", + "8/4952647400923854788", + "5/8765649801817931765", + "7/5505639780925884087", + "4/6224181708507224014", + "0/4014507138118884940", + "1/1926465354138495691", + "9/7318523362320807679", + "7/5655297031848751247", + "8/7495877693266370198", + "6/8817376205683152716", + "7/1655207796087338957", + "1/7370521945282859161", + "6/4488905162705238896", + "8/4061457177374126318", + "6/4561658199146546106", + "9/3535631793266073839", + "7/6996692214182485857", + "0/1150871373227540420", + "7/5541032730605397977", + "0/1785874600659089590", + "4/7740737204873827184", + "3/9045430738422343863", + "0/1454391595792024990", + "0/5812530083332159160", + "4/6109644558641171934", + "8/2230197925036842028", + "7/3668161730048930967", + "7/5608990758730725877", + "0/7632387832653148790", + "1/7055371745312478211", + "2/2580656835151788772", + "5/810097974355430175", + "7/3548333452369501977", + "7/7082291171348491037", + "7/8976266803829290597", + "8/538147441349797538", + "0/1751718851800353870", + "5/7905476443108696165", + "5/1379665771528163375", + "5/8156606713071278675", + "6/4885119104200467236", + "2/7497513472373697872", + "0/1120169165975032970", + "7/2927712622941703557", + "2/1617350706952550452", + "4/518220629359550044", + "5/3073005548923964945", + "5/7158568933274775485", + "1/6056622977699794131", + "3/8715871309764737423", + "7/5250038153480526687", + "3/3757514141713464173", + "9/6119522983111913549", + "2/5735140442319700032", + "1/7336157212919190351", + "4/3722743615576184604", + "1/3938053963121257821", + "6/2775703195622881576", + "6/7794372685650658136", + "1/4784470651049419781", + "1/5775331282393679171", + "7/5638300725711077287", + "2/6775684926108168842", + "8/8802617316804881058", + "8/9170394299338406128", + "5/3601121988797006105", + "3/4742822505195047423", + "5/410310902790747495", + "1/870507587588303821", + "6/4173090343132444986", + "3/6164110345576644123", + "8/4227483740474449028", + "8/584371434514107988", + "2/4241220429075608092", + "9/6133936345830358259", + "5/2587602336364978065", + "5/5323542932729557455", + "9/3885106886282085269", + "3/4649608112536448993", + "3/4069442913316567523", + "3/6886541766924128273", + "0/2535787320550505460", + "5/5512299383949221015", + "3/2790598426615737153", + "4/6648436694793913434", + "3/2010086882502103453", + "9/9199834772793170489", + "1/1839775362948760501", + "1/6042660578361825281", + "3/4853960865412697923", + "5/6882627477254976545", + "7/8258530900525317797", + "6/5018467555172224596", + "9/1979612204711881669", + "3/6820830952329841473", + "5/8332147483870134045", + "3/8380297858480414663", + "7/8258425214213383027", + 
"4/1129372759822483274", + "9/4252282544438245579", + "3/8135316096050396633", + "2/5594867180276542542", + "8/7031160447962781048", + "3/2761604229848006273", + "6/1770969254178168946", + "8/234693904232351858", + "5/6082061628203004175", + "6/8619650956882442056", + "0/8987093455296006360", + "1/1896353441202863211", + "2/2562185340176823432", + "8/2011793221759338638", + "8/9077876338219475448", + "3/5986059134672120443", + "3/6232619294152759233", + "5/3052318023178166285", + "0/4241445133426269310", + "7/6500386753186498377", + "3/8421423456602614583", + "2/2600413665397807592", + "5/3668270348889094215", + "0/6464398875425309810", + "4/7473363650738559624", + "4/8397418225075778264", + "6/1147311976714696196", + "6/539587857615219806", + "3/7810301450033839893", + "9/2429567926287357729", + "3/535314897363476913", + "4/5642576599883154994", + "8/4756060593283959048", + "1/8130519058665699671", + "0/8385165903303542950", + "0/6923907163634854460", + "2/3890149047199706732", + "5/4571783602257077505", + "6/4983199983005692516", + "2/584243721948236092", + "4/1275221937888675694", + "7/7844008524188649967", + "1/6263569656144205581", + "3/3760642930810643713", + "6/9135158949291689616", + "4/7534109278635576804", + "1/3768881479776461831", + "9/3212883654502286909", + "5/8194607395753728245", + "3/5812624248298468103", + "8/8643969398140227598", + "5/974847084634639045", + "9/2466830029194403309", + "3/5759217260781714893", + "4/775839305180153794", + "1/2978726768080688841", + "4/6575575988497187954", + "4/2137814368426259974", + "7/8760674251911688127", + "9/1368585944037385529", + "3/1649010589249137493", + "5/1927412575898895245", + "4/9181752815174074674", + "6/8529419900383228446", + "0/3552976378185837890", + "4/164816417633468004", + "0/302946235899755680", + "1/8071225484420498401", + "5/4673719848154245655", + "5/6275331455457888475", + "1/6207815496521628711", + "4/4527177686805123514", + "4/2940684392054129664", + "2/1114201452084140282", + "2/1994851219558900342", + "7/8672688523242998927", + "4/972714918529738804", + "9/8260309085173765639", + "5/7239028683773972565", + "8/4876548403531686188", + "8/380109263013960278", + "8/5738531055154412728", + "8/4938554759069764028", + "0/3411627188325814460", + "6/2452017381919292466", + "4/1309620416716834254", + "5/6283721902222867925", + "3/1166579251060663513", + "7/3397502466712582407", + "7/562472628488327997", + "3/2583900293180231763", + "1/7024767757390302661", + "7/5870384789332664097", + "4/4942258466275469264", + "5/7854939041815049245", + "9/9116818993398670839", + "8/3307917284630279618", + "2/7804592879366353202", + "4/1510675542018136944", + "1/372017696341005441", + "8/2219618719919349868", + "5/2409000945888285955", + "0/9126525980688830890", + "1/1370764694249308691", + "9/8946325325525795089", + "5/4787552419692571295", + "9/7643997173703619589", + "1/8071866160991178061", + "6/651932552770076566", + "4/1746433897647684184", + "1/5668698929381603511", + "9/8993237433319540519", + "3/6483630865247989743", + "8/7091012077107151268", + "9/8180962091454486119", + "6/2245767620298579616", + "1/2070662986456721381", + "7/8465586655126213017", + "8/7501775172727470598", + "1/8533264211994820631", + "3/2603674588060190993", + "2/6673112516343973842", + "3/5921895508919975663", + "2/1361522920851995172", + "7/1373907050617161787", + "5/1476619738614683245", + "7/3941000891086446027", + "1/6854277842758547861", + "8/760676716597887258", + "7/7407126950568681857", + "7/8839528177884749667", + "1/8793983573847096751", + 
"0/5843572872469189880", + "4/6042571052532137804", + "4/6329823776576556994", + "9/9039618864457648839", + "1/8056943261237151951", + "2/6781928617205810162", + "8/5280204418074437448", + "0/612584390295458060", + "9/4397549862117197469", + "8/8522340899849638248", + "8/894054761481193788", + "8/6581761659983265818", + "8/7557821104113531078", + "2/8055198646624159272", + "4/8023340313520017564", + "2/5721112656327594582", + "0/2233123243502371680", + "4/8148693544763795484", + "0/1973087003520799280", + "0/2137248627372066400", + "2/253500089699158942", + "4/5277884919433851414", + "7/5226754176773440347", + "9/3115594321255474619", + "2/1085098988521224662", + "3/2743341926469026283", + "2/4983511911639454242", + "5/601444595769148205", + "0/2474288062315403430", + "2/6972081213680460022", + "7/73411066158979637", + "4/1980292046866174444", + "2/3417956562734087162", + "6/5403757737808693686", + "0/5085527463429017380", + "8/4820704315731528918", + "4/2020431458010300094", + "8/2902614432844749478", + "4/5165499202246879674", + "0/4408244764032758140", + "3/4651638588223069693", + "7/1846000824917760117", + "8/4747322828625390648", + "8/5858928099574775598", + "4/4353434294685314914", + "2/1337760094983144172", + "7/1734054250815685737", + "0/7098852939373912760", + "4/6616489800756417634", + "5/2892145959216557405", + "9/5820455622283630369", + "9/9208956416626396209", + "6/6875738992607401876", + "3/4029457123324950633", + "1/1675920790798262011", + "0/3055330343117648230", + "8/8628843513319695498", + "7/57237349332826877", + "0/3117819230832470450", + "6/1361313827230719986", + "2/8402225258248787622", + "8/4976438312213784818", + "3/412628541458115723", + "3/6384163679516639743", + "6/2878494605141060026", + "7/4371732289267430857", + "5/7457336369468105965", + "8/3059513941787256428", + "1/6261259591666713611", + "6/877359955137239176", + "8/4140985149034544978", + "8/8927622165873276898", + "3/1864468142749822983", + "7/743264538540616087", + "1/1230351264304094151", + "5/3592989807433612805", + "1/6301864452870363081", + "0/4031997518835642720", + "0/358380539472537950", + "9/6403971156614423659", + "9/6564852178917408329", + "3/1348145177801223503", + "2/6567653482708784962", + "3/5512438410739868243", + "6/1929328829708697756", + "1/2698399947757247001", + "7/778852268964031157", + "5/1553109660323820815", + "2/5606576375583840922", + "1/5309985966188515021", + "6/536847937283279326", + "8/4374856144880351368", + "0/2348528382860728540", + "1/7381799714037915231", + "6/5951381068241595436", + "9/1308656451840691339", + "5/9000730533707063325", + "9/4157508887049101829", + "4/6405470048325239794", + "5/8914592829936229765", + "2/3026578617142804982", + "5/5487826867511751215", + "6/2143671061859109706", + "9/3918144620361145059", + "5/4922801891149834995", + "1/8950072201986383951", + "6/1007874119190303936", + "8/5051585679926727928", + "9/5216798499745187819", + "0/6589476796076582770", + "2/5363921307373162102", + "2/4615248582158324122", + "2/6013393559398322172", + "6/8787565301729020026", + "8/3405131253449074758", + "8/6374394687269706758", + "2/8461396993232533562", + "5/574717111805950325", + "3/7562444953263949103", + "7/6622590983405748807", + "4/5479521126623014494", + "3/6775395444992863063", + "0/3679427845364521740", + "9/5459070290607007729", + "0/643940367494134630", + "4/5004303770765001684", + "2/7677588008804024632", + "4/7587336144599321074", + "7/103206367548666917", + "2/8714033237983024582", + "4/2589556697760145714", + "6/2215831843157804386", + 
"4/720582893869885644", + "0/1163149156109630540", + "1/7949478382700571371", + "2/7108410249690405152", + "2/4921719778543797742", + "3/8766221977801378523", + "6/4110820564476106136", + "4/1330412166301364504", + "5/7548009462455177735", + "0/345563583522419980", + "5/8091097875111713525", + "4/2231081349844619724", + "1/5478602410104597101", + "8/241848706740933648", + "7/3215962998547783007", + "0/8802163189837214020", + "4/2303678306780963804", + "5/765745563266945535", + "9/6235653485270694139", + "8/8652709584409208648", + "2/3558897802501051142", + "5/4605249490218942505", + "6/5550975342252500916", + "7/5258432391744652647", + "3/3427362232126273823", + "1/8491887150922436151", + "4/2515322958546858414", + "4/6363161414586125644", + "6/6918720416396626396", + "0/8324855682355746930", + "7/8168475534019789177", + "6/6451629083478091516", + "8/876393375895190448", + "2/4981644868190999182", + "4/6813732266138802024", + "7/5441695703649805627", + "6/7081278709992789106", + "0/1054990230631284690", + "5/6882919250922006945", + "8/5272891210772284628", + "3/7972288391333055973", + "5/5736421670225398175", + "9/5558850674166313239", + "6/1957686632285203576", + "5/4607750789674737255", + "2/320047844533056742", + "5/8924332306331340205", + "4/4669126319615355414", + "2/2392863452824280262", + "1/843891097873291681", + "2/3398892359852095082", + "4/1925930691719566914", + "5/8825906134506774525", + "8/8196140762409389148", + "7/8360075094071756157", + "1/4356948780525637941", + "3/6810942544358591593", + "4/8654672096102166254", + "7/167618900961630167", + "9/3584707345961013589", + "1/8043753293826466921", + "3/1515748007853837203", + "1/5120936178785484021", + "9/3789707065715517059", + "8/8893003742761477098", + "3/8396721553274572143", + "1/2129677048178831621", + "3/6712470536061135383", + "1/1772642229914027881", + "8/5117552499753534248", + "6/1621469686012348696", + "6/2803692313695012616", + "5/1290453350363699595", + "5/4303444634751922575", + "2/2787341911699833262", + "8/5888873057433772138", + "9/1198271236710734269", + "3/7592906988948093593", + "4/1420767026212678304", + "5/8731480702976740505", + "1/5497367913213514431", + "7/2774899634430421147", + "4/1819150750760271614", + "1/7747677407119997091", + "7/4123617359075585797", + "1/7026223050431745851", + "3/7790146585492226823", + "2/4306981254480921172", + "5/6000081993888696505", + "0/5267643633260029660", + "6/8788904763861909836", + "1/2176672401683248891", + "7/3090507406273800117", + "9/2492378180389726529", + "3/5919087748985541893", + "3/6038754210383974343", + "2/8888731159261262422", + "4/2498129228513019844", + "3/4474392028623326463", + "6/7850957500989930816", + "6/7648274722484611426", + "7/2364627772446858507", + "2/469576624933086052", + "1/7963759390042817451", + "9/8613631572628403819", + "5/4851265760560318445", + "6/8499549334747125956", + "2/8668604181775298142", + "3/4250916351467305983", + "1/7518600553195712971", + "5/4724574292448295405", + "7/5857894972559405207", + "5/5515292016365961105", + "6/298390581661841156", + "4/1379936808526569174", + "4/3591776592395398824", + "3/4413758190507918203", + "7/2678954200110114447", + "9/6108341795824345229", + "1/6408467424750676251", + "7/8469587705845160207", + "4/2576710871901316234", + "0/2036288364139591100", + "7/4524888762390440177", + "3/6361684564702115373", + "5/9112870925224711385", + "0/4017972466061235320", + "1/2412000209721734941", + "7/1967525747797518947", + "2/3950416373773321952", + "0/3777405962858583000", + "9/8060397078756704809", + 
"1/8384653570951533051", + "7/4777056092375273967", + "0/7270259389917633810", + "6/7583537169113090266", + "5/3976337878756675875", + "1/1024091685743800491", + "3/5551852023216327773", + "6/5826937695826603946", + "5/7684475878868651165", + "2/5243993199717884992", + "7/711117669338262707", + "2/1816950320191828622", + "6/6331313562913727336", + "0/6825055588441849410", + "8/3192134343534816628", + "7/4335356440674269817", + "6/3833568554217287656", + "9/915494314280623729", + "7/1557653509735398987", + "7/8510863879176779487", + "3/9167578222778152273", + "0/8125541455348828260", + "1/3257076063288302691", + "4/7084213414247331784", + "8/803046337227527478", + "8/7290242184337605778", + "2/4582452456001474792", + "2/4578532745544698252", + "9/4104479662162089449", + "9/5059653736681774369", + "8/1113822980571357758", + "4/8228961213094661464", + "3/140929002833075533", + "9/2243156373881173259", + "6/751181100241885956", + "2/3901075757987385912", + "4/3300102868489926654", + "3/4687954388558310823", + "9/8952007949894294859", + "3/4224488058671093503", + "6/7428411536380333896", + "0/4539104050221595010", + "4/386243996448728294", + "8/5979261136021236388", + "2/4935768397542009742", + "3/1060332037922612303", + "7/8858724294500530117", + "3/3428652350546452423", + "5/2597446306564249485", + "8/3590391703867411158", + "9/2302378409318610009", + "6/3861832267654628946", + "4/3205339589851963504", + "7/756178627639217907", + "0/7761370645462552990", + "3/7535457796884387173", + "3/1421447611127514463", + "7/1645492824602948397", + "0/832552074307715200", + "0/2820506325701287300", + "6/4265777835770452446", + "9/1891252978769396609", + "4/8202322493827924894", + "8/5802016892273525438", + "8/5925920343286976138", + "4/5294978393562727704", + "2/5001060709904160772", + "6/337749794093218266", + "3/5175715654123618513", + "9/1491372482247457269", + "5/7999176163560492455", + "2/7167689955208443502", + "6/7896543262548259036", + "3/7029355724872194333", + "5/1822719985272668515", + "2/6489885688909888212", + "1/8818702093196259691", + "0/8703001686662631920", + "9/7558414764477786509", + "6/3730592115841390116", + "6/5293185631868511046", + "8/7944094006503716678", + "0/6131877353824263450", + "2/7755458460076723022", + "6/2734357366768942676", + "9/123702526995621369", + "7/6098925454234828207", + "7/4937996678296405807", + "9/2438178718896170319", + "6/614307240158509966", + "6/3685086748072790426", + "5/3262856318884124765", + "8/3773007496128003578", + "2/6616378070566279432", + "2/5781043753764977842", + "1/3679865128460699451", + "2/3896690311597889932", + "6/3876093696181078996", + "7/2924717007457235247", + "6/609228756601328056", + "4/7932621472215696864", + "0/7676850743956999790", + "0/77926825851157580", + "8/2542962790475600248", + "3/4280673999031803173", + "8/3045734010737351368", + "3/424803040672360533", + "8/2518516417720717958", + "0/6807947309263532860", + "2/8750745396235730492", + "3/2405294150332644693", + "8/8716025404225961948", + "4/1865440874326808174", + "7/5728567582902274387", + "8/2920449962139492788", + "8/3393334139690549558", + "1/8878645883629898051", + "3/6735400019433760513", + "7/6078640853229892747", + "8/204685936719347078", + "3/2173882191287287973", + "0/557543092906025910", + "8/8015923251209631708", + "2/6687798513594630702", + "8/7392610169329067258", + "6/7061730602882483546", + "7/5611429494412351807", + "4/1088625571770347984", + "1/8961287400028426331", + "6/8948016012905208626", + "0/5568033024168433440", + "5/3237154070470632575", + 
"0/8662334761132098050", + "9/6383288359317283079", + "8/1350898358994297678", + "6/1894832808248941886", + "4/1615067168346010564", + "8/6115059152409314808", + "3/9073171352793703403", + "6/4782341151038343676", + "1/7982134758346735171", + "1/7471784842568411511", + "0/899175114290277630", + "5/8537597457640397745", + "4/890540361341764744", + "1/2857146870294975601", + "4/1214415158703584144", + "4/7100252542971187364", + "2/5823655032615875612", + "3/611522516428915833", + "4/421450021181577604", + "0/2863022553050348660", + "3/1726383403948631923", + "9/1874497077094804239", + "3/6282479406020274673", + "4/6086473750063161644", + "6/862406935925477956", + "9/7421953096773127689", + "0/3420340460205253360", + "2/2859360223227588472", + "8/3592106334595194488", + "4/8241372033020757794", + "1/1993927465066194981", + "0/3249323346357999770", + "1/3304595581153690651", + "6/3751647365778694216", + "2/1297111311780761872", + "1/232181091833125731", + "4/2759104301105855464", + "3/2179608483405829623", + "1/7521735594132563001", + "1/6904704347334266511", + "4/2679580053945281514", + "8/99509871088504618", + "1/4882345346959248191", + "5/9066317439085042265", + "1/3791999351619603141", + "6/2440047399494716156", + "3/8608402951152839563", + "2/6872024579834071812", + "0/8996604867263790800", + "3/8016213920852619463", + "8/18809225954186638", + "5/1133867141076447905", + "8/820878972398632328", + "7/4682340036567578907", + "8/4219360078924715298", + "9/4456586612107755609", + "5/4604408475005145105", + "6/2729080964393409736", + "8/2393197126678007048", + "7/8620446890337946547", + "2/2291176475620534602", + "5/152194356532556365", + "9/3813031094800015879", + "1/8124279916985855541", + "5/1752823520898611725", + "9/7745547890668861339", + "2/1195607892704345822", + "2/8930170258662326922", + "0/2947847154051195470", + "1/507428325694130401", + "4/670627483968324134", + "2/2323696964544894012", + "3/9110325048752364583", + "8/2469689865204046238", + "4/6755015732946048624", + "7/4183371216670817507", + "9/6949817836906934309", + "3/6000720520784890733", + "4/2906611932107334544", + "5/2942198118445381225", + "7/7611993569620641767", + "5/2492810665731128865", + "9/7525491576707522269", + "9/9074000827306474119", + "9/2941209749783294199", + "0/4499744572796739710", + "8/3818469494048301898", + "3/4614363096208988503", + "8/7926812379467790398", + "1/7742479250397507441", + "3/4326142117361364913", + "0/72147510554833510", + "3/1324774703818762023", + "0/1481930600812860010", + "7/304706673322811217", + "3/4306933966373216673", + "6/7756128295335027976", + "2/982459988694164202", + "5/6120747397795702175", + "6/6347895228059465696", + "8/2637144302472469448", + "9/3943434221127928489", + "1/4360318924521426071", + "5/2343131210482928185", + "8/4668067667426281748", + "8/4942772865999612628", + "1/8407170516812719231", + "3/2245350602659207533", + "5/4739340573082047205", + "4/3458308737906276654", + "3/1167869042355662443", + "5/5387356868836958125", + "9/3710592616763486589", + "7/5198359160002068797", + "7/1316465203731685087", + "4/3176796226194638114", + "8/1746266356809344468", + "2/8162836453019997012", + "8/1365101476049029418", + "7/5284427820181663887", + "3/1699036783016937613", + "9/8165997745332986639", + "4/8009122004212655614", + "3/3795995748269893103", + "1/3285000873683184071", + "8/2183785513278266008", + "0/4695100778877642180", + "5/6770489711774126325", + "3/4780043027367489683", + "9/6620987462383596669", + "7/448405130221497117", + "1/1374644757487551351", + 
"1/6053480691948476261", + "4/8383955103242914294", + "8/5191522830607892298", + "3/5257159256072657703", + "9/4609428067927713779", + "0/7352807370309723490", + "3/8181200725238065943", + "4/3569527460343298464", + "5/7003114517603992805", + "3/2821896593448370223", + "3/5128221828209122063", + "0/1742071160146330470", + "4/1921458465061714444", + "5/2478566354116729085", + "7/2179984061639500387", + "7/2849031253588506607", + "8/1797243928908475568", + "2/2628698222865354312", + "0/4308883545939832200", + "8/1647222673052534058", + "9/8954397882785748629", + "6/5699753787542344376", + "8/1042071595948307898", + "0/6218137938559459280", + "6/891374784736587466", + "0/8067542465111648760", + "0/1094769317607100660", + "2/7144245072362960062", + "7/2537308324409365657", + "7/1330422229959245447", + "9/8091558534956777289", + "0/3211452901494305850", + "0/6191522870031927060", + "0/653919957510492670", + "8/8990711456445380208", + "2/3906538436258916502", + "9/1187201959238218689", + "1/7377970606549247911", + "8/5349708031002390408", + "0/5315869691386942150", + "8/6733270477505368008", + "3/4532776261430191433", + "2/2544289657183712102", + "6/7106014450214944706", + "1/8026730495184343161", + "7/2632756318618985237", + "1/4006466744373691681", + "4/3297370904419012394", + "6/982681149074590146", + "6/1714141237582573626", + "8/177126263812582718", + "2/8166888809842957852", + "3/1991305752411748703", + "1/8968293758330659061", + "2/1126370395506195482", + "6/2747713364018190096", + "7/2326669202497061297", + "4/5847706480553909204", + "2/8087977535576956052", + "3/6816613200692252333", + "3/6037352517894001683", + "8/1452342083116440228", + "2/7605040468952058872", + "0/3456572857822945460", + "1/633964156532929321", + "0/8584665334663904430", + "4/6566149738642329014", + "5/2681015857141695065", + "1/3194620038194304371", + "2/8036634785346655152", + "1/4435539377019562341", + "6/104083846463653296", + "6/4830251744726051586", + "9/2938896905282788889", + "6/6663496309868691656", + "2/3315596700176484522", + "7/7238208118765380027", + "8/302931118282531688", + "4/4025785225347738114", + "5/6968135493788179255", + "2/3736702224938149022", + "0/7883639234543643980", + "1/2561005392093649431", + "6/1619727771132196", + "6/455278535750640176", + "1/361703629843041611", + "7/73892997665943137", + "9/7695488511771006859", + "5/3923264893621553965", + "5/4575408269977176325", + "8/1930668346914029368", + "7/6032557644304970147", + "3/4850584457507960753", + "4/6948277763870990364", + "5/4818958880325362445", + "4/1923360748803707764", + "2/6912880592228977462", + "5/978312702479422675", + "6/6886631484026934796", + "4/308016549500254364", + "5/8756409907706985085", + "7/7901734364760318407", + "9/7825330208557115759", + "0/5011954836132549700", + "5/2917018918775611705", + "5/4581482374647820365", + "5/6821114147187091385", + "8/4331594034128865768", + "9/3422027655434034429", + "0/5128321562761302110", + "6/3714500124964748866", + "1/1324200822593727591", + "8/3840787256163660158", + "1/1553435205695252201", + "0/1505326246213873410", + "5/7818333083648173685", + "5/4691862570716278115", + "4/6980262620923349184", + "6/3614567971831984596", + "3/4895206470943262623", + "2/7158581298979753692", + "1/1881237446247271541", + "0/5140148812800158510", + "0/5276826661895662340", + "8/4564313267507135358", + "2/4774745053400346812", + "5/4544304499696152595", + "6/8790910723500286646", + "8/5238752242489307788", + "4/4166765201001939174", + "2/4470527938291752382", + "9/5492737888134396629", + 
"9/6837833406876233909", + "7/1562631876697279097", + "8/1738641036506718258", + "9/125308318973788789", + "0/1891594019375082730", + "7/8856191226742931157", + "3/360826282443203173", + "5/1596187325661735145", + "0/5772652057330389900", + "3/3837039186147411963", + "4/7382636977460649564", + "0/2870022636887407510", + "8/899930907096553228", + "9/341480619401308239", + "3/4466978418259763483", + "4/1163744452041068544", + "2/4991506006748497432", + "2/7609687257635001482", + "2/8865341321488399942", + "3/2106930114676888033", + "0/7361430858099851290", + "0/4434820371983992720", + "7/8493640469981656107", + "7/4792342067566126457", + "2/2896781016698752572", + "7/7962157309484740497", + "9/2325320397705585069", + "2/1027141227201743722", + "5/6181713042752985595", + "3/6723068652267476603", + "0/8503184837441320660", + "0/3315598912248432960", + "0/8761204452479506940", + "3/8758294952013159193", + "4/6562447960590230054", + "0/4905526019883109910", + "5/1802261861530848855", + "6/4248923061897066066", + "1/3477883455646406011", + "6/8258022817572671866", + "0/6992076077809456630", + "7/4293719228625953047", + "3/347664028893505503", + "4/1357431609893521024", + "4/6931493768044660854", + "7/8531198171169802667", + "7/2475529140890849617", + "1/2081763687953869531", + "0/1182977127377327180", + "4/7252217922294658214", + "1/545752383321418481", + "3/2901031427745908863", + "7/8280250341563906327", + "8/6417840953297251188", + "0/6837057087994983790", + "9/6143797614477186919", + "1/3004438599090196761", + "0/2928686519363970350", + "1/5837583756588472781", + "7/1578519780990340007", + "7/6074094537782176407", + "3/8480201958172487163", + "1/2356386726514772941", + "3/6313158990697989213", + "7/6558758684957718217", + "0/3994707169426685940", + "6/6674575345034067596", + "9/4758057940417596609", + "9/250115788899874759", + "3/9045188123814655173", + "0/3278190700028945380", + "5/6003909188506167085", + "0/676446068572630450", + "8/3835485699295998938", + "2/358418362623697102", + "8/6753399157804930078", + "0/5564247961016097910", + "2/4307695340761397762", + "3/3428571247776109603", + "6/8587883906979589346", + "4/7741186907561305514", + "7/7134112675472447267", + "0/6695851703840603160", + "7/9034142637439526157", + "8/4237135240793970388", + "8/1124001505858808228", + "9/8154044628199712539", + "1/2142392856728163321", + "5/7308848342781355935", + "6/7281199878162501306", + "0/2876920112795465120", + "7/4719673542266804677", + "5/668058030819763685", + "9/2836921880045946469", + "3/2476614896954632403", + "9/6924593718279977209", + "4/8102987591353279794", + "6/2322393006208016096", + "1/4492206211071919171", + "2/7155933437557040652", + "2/3650773268181865642", + "0/6433997999306839210", + "5/59312523474362225", + "6/5311081761986306446", + "3/5812376474912461403", + "2/8016903333732566862", + "8/6925856509068534128", + "1/7000285852793615681", + "0/1208641142848445030", + "7/9069717160838494387", + "7/4229677240903762117", + "8/1226551931970166128", + "5/257874069776736215", + "9/7942581475111943519", + "4/7287687206372684614", + "3/1610273395050230973", + "6/6426496344339447436", + "0/7741650732767885480", + "6/6336741150006080846", + "7/7211962420288030717", + "0/8676391309747754780", + "5/726092182314964565", + "6/1160783419817403606", + "9/4019060573068348039", + "3/6129630079615679413", + "4/1117324638092061774", + "3/8568135705504126663", + "4/5615371321837056324", + "1/6637465496852583151", + "6/7246448030075953336", + "1/8023475493372900331", + "2/8853962289119748302", + 
"9/713256505751234479", + "3/97035625898389623", + "5/4209628906344933275", + "7/5248654605167672387", + "7/3678260627439792087", + "9/5056679099815267779", + "2/5182478843488378582", + "1/4377347127156968051", + "6/6635824768447080716", + "3/6987101957384148303", + "1/5255649376418142591", + "3/1420680087979130883", + "0/4541661979292429260", + "4/2984136409551243744", + "7/1232447729192267277", + "4/4072867620425461544", + "3/654121199487383363", + "6/4135359545282398546", + "6/2961766860896757826", + "6/1776828297949291906", + "0/2476536722195522020", + "1/2965163607045625291", + "1/6373897191363331861", + "4/3378031604758102524", + "7/2826671014422644457", + "5/4259844274850388665", + "2/1222185813406301302", + "5/7755615164931500025", + "1/5446526286182757981", + "1/5946036918249233531", + "2/2649833030141553612", + "8/1289954182482089208", + "6/6044593792542242936", + "4/429888086316415594", + "5/5529422519714734325", + "4/8450961068176067064", + "4/2601024431255955204", + "6/3220520921464419716", + "4/5986125334033160824", + "5/2865518998818733885", + "7/5483013374646365307", + "3/8815487538704534553", + "5/2495393152017985705", + "0/2153306245023308350", + "5/8783542745945818445", + "1/8801762025343376001", + "6/5589512065920903446", + "6/5888646059452326496", + "8/9148803446193891908", + "0/1052085118322297010", + "2/9056941920805620502", + "9/7185288406454968959", + "3/6609141258291937953", + "1/8088760010814361871", + "4/6040016699399544514", + "8/2944234745364811118", + "9/2542611477611055899", + "0/416400926465403190", + "7/219348211883382827", + "4/4615693772805331854", + "6/4573324042483505146", + "4/5323975219609125884", + "1/2257951961505802551", + "8/8477433381594629588", + "8/3987019081806761398", + "8/8378220260825390588", + "9/8683429959883008169", + "5/5152657487684857925", + "8/6540077155349764478", + "2/4406466244032977282", + "4/3688917147358425914", + "3/8740912660242809553", + "7/2452939211652804017", + "5/5324114501373265115", + "0/395917653879499590", + "1/928562766937087941", + "9/1326296362535049699", + "8/5803812432895942538", + "7/7560793021565864917", + "5/9197559596029901955", + "1/752206033499138421", + "5/2713974094194593585", + "7/9214411016370517477", + "4/181501812366880494", + "9/846106310419276739", + "1/7670972561957042601", + "7/6267206042347047197", + "9/1137441052038023029", + "0/6336421179279448380", + "4/2853914823263748784", + "8/8563973439277608978", + "2/4284894537432185732", + "9/5088332235915482559", + "7/8326734325758961337", + "7/4709928054918724267", + "6/619108808472811516", + "7/6443212708233366967", + "0/8857664845776689800", + "1/4883353536191977991", + "3/6476415938957962633", + "8/1563551224442644008", + "5/3001119755655036285", + "4/3193246921111492914", + "5/8551833531961465885", + "0/6367021567557281020", + "5/7824425634044511275", + "9/4229435630126647459", + "3/2976123586979151473", + "4/8287683889410939194", + "9/6375451357565235809", + "8/1182180800237853308", + "8/8894207403797181608", + "9/4621076116938053729", + "8/7672070532209536338", + "6/6455758149727920476", + "6/800967081866261956", + "8/5797272304598611408", + "9/9214188650413675449", + "2/4656027031299733692", + "6/1846287398346306336", + "3/7164303450729938133", + "4/3140330465309469304", + "5/2802171479466557265", + "8/9132632987711146758", + "7/8682026234508631487", + "0/1480997811996769960", + "6/2423122048958302866", + "8/7744424987363994468", + "7/1850189607069988037", + "3/4443534317825729653", + "8/696443323420237398", + "3/6686126289280304083", + 
"7/786902004895618247", + "3/5246353289513024673", + "0/2731625269158914030", + "2/7362462614953598532", + "8/5805122949638673178", + "6/3876442191714923996", + "2/4330399485866046392", + "8/2464552546540696128", + "5/6174125245342065065", + "8/2402387514073425588", + "8/3834063851858217978", + "1/3012098176244451101", + "4/7076019118369987374", + "2/7917943121816536692", + "0/1830737733009568010", + "6/1106825368522633906", + "1/299966347030478951", + "7/6944164616272163077", + "0/8644462100974342710", + "6/7937801517364807166", + "9/8005180700994099529", + "0/8712287634885785010", + "1/7665680893040044541", + "1/2469389763710324431", + "9/1597693030140069369", + "7/1562559798288763757", + "0/4774020387355159990", + "1/8809540013739228631", + "3/1418026253792936913", + "1/7658140166203647911", + "8/6256217486340176368", + "2/6564024493448091662", + "8/8911826104761921618", + "0/2421648729184605960", + "9/8273171853074724129", + "3/5464490792055094103", + "4/8202516452963305844", + "5/4237326580663264885", + "6/2474422536847877026", + "9/2946854958386009349", + "9/4972708202658697089", + "3/8082340156973858383", + "9/4002526854958849279", + "7/3557861744347677047", + "2/8795554192634359092", + "9/352786349244459809", + "8/7994083516040553008", + "9/6015289663680153309", + "6/8007455773227362666", + "0/1069131680353162170", + "1/8572643371774061641", + "9/5083446443064735049", + "0/4512035646585254570", + "8/7684133274158274388", + "2/277394234064294662", + "9/8055131480871221189", + "9/7070510690403614039", + "2/2303365712814093752", + "0/3014427153772381070", + "9/2486905002130737019", + "0/8169363132395977300", + "0/3083861286786672640", + "9/6379082814816180199", + "1/1727495495037525271", + "5/6485850038781173505", + "8/5247819082719237038", + "5/9085856422245907075", + "2/4764291500099860702", + "3/3794102652856439953", + "9/5587156735571981089", + "2/8650526491329306152", + "2/5537801437198417302", + "5/2571308322050021165", + "9/4897982122583152209", + "0/2332871169888526680", + "4/2686111172578165364", + "0/6224343331096651320", + "9/8732915486626078579", + "4/8727414523026999034", + "2/5285613892080024152", + "4/4464811440145937054", + "6/759981463629515636", + "2/9123425628260544882", + "7/7860658028505071527", + "2/4063222076047109752", + "8/1682137337435337278", + "0/928566077423564870", + "7/1538062269854027607", + "1/4838557717083732971", + "1/2321412400115669481", + "4/6022784458460749044", + "6/4776362659339057406", + "7/5892712637564781187", + "0/5737662982571412510", + "3/2394726150214010383", + "8/6120717735691646888", + "4/1357553183842770384", + "2/8278751952745244642", + "7/8153265108221190897", + "3/4933478127395617853", + "6/6441992977738473846", + "5/7957753600389701325", + "6/8026201358906142526", + "0/2844232584388146190", + "7/4493977927170549647", + "7/3976716911050222917", + "1/261920665564106181", + "6/3245623509351827946", + "8/4930921562547126478", + "6/3480439687644690976", + "0/6208651048303545550", + "4/3675808841972851904", + "0/5023207705574688530", + "4/2229033904446964734", + "5/5059379111593259445", + "5/4211505410143078255", + "8/707834473819858218", + "9/5999151639856838319", + "4/6480879943499945914", + "1/3844742718198866321", + "3/3224208604106063503", + "0/6454837718794724090", + "3/9057779957839848993", + "3/5805439166977441693", + "5/4127744670814663345", + "5/5658556357865644565", + "6/8321044860716195656", + "2/1525408748696674952", + "5/588975810108361385", + "8/9193997452982740708", + "7/6096374991071868667", + "1/721898287848381511", + 
"9/5408734040059943929", + "6/7302393513743356816", + "2/8208131451573167262", + "5/8494967574395793605", + "6/3704872366837478956", + "5/5177034664611963075", + "4/8591778770462510534", + "4/4393199761063374914", + "0/7867645707095002640", + "0/8317206228696999570", + "1/3517654935929737331", + "9/7733268785057961119", + "0/67914400118624960", + "5/333231574395362085", + "6/2559688824771243036", + "2/9212177344045717362", + "9/3254852644650450669", + "5/4239312685834213005", + "0/1367759809654934240", + "8/3760833951013473648", + "1/1417210787874182931", + "9/2061916593163484299", + "1/6402089046651318481", + "2/2452883737703436522", + "8/69956135420498558", + "2/835359884374540462", + "9/4873497324529554159", + "5/7036730761393037185", + "3/4067931640654570003", + "1/5565201586005995911", + "1/7518098846710786431", + "2/7422079335257068762", + "7/3404490400099647277", + "1/6200936985459113441", + "6/539931213031936656", + "3/5006361672693158133", + "6/1273025275582045636", + "8/1593991052613407408", + "7/5358436243058806477", + "6/4121135698521219216", + "9/4993017647699181249", + "7/272952601946739177", + "3/7063605985846675843", + "5/6896499577398390625", + "0/108074582846007280", + "4/6720774836290583384", + "2/1989679761024725392", + "1/6229148826794015401", + "5/1584565591222609975", + "1/5853062442057935571", + "4/2310505206160112434", + "4/5583408578483755974", + "5/472576464446186925", + "0/7606804739301243480", + "8/1116943967067500038", + "9/1190977734841671799", + "1/7810098897817386061", + "0/4054662229917871080", + "9/2598099686130759029", + "8/6553813050496390128", + "4/659329376418505634", + "7/5932324026460831927", + "9/6581912274137598049", + "8/2060346647618328178", + "3/971942182065411463", + "2/759528467112785452", + "9/5508709001248505599", + "1/1494754633642073491", + "2/7826968640656946842", + "2/5408195517211826802", + "3/1255821982775001373", + "6/8791653288180724236", + "0/465233653445958740", + "3/453005548423575233", + "0/5711191903018636410", + "2/3938418644320491122", + "4/6549002942155254534", + "4/8284625807371769184", + "5/4614201587041350625", + "5/3929038087574881495", + "2/6638281238689435932", + "2/7446641155049495732", + "2/1174997979504657602", + "6/8233652468898765056", + "1/4924763891151206171", + "1/7463384224808695041", + "8/6607668966067396788", + "8/7464638944763495978", + "2/4436706706077304612", + "7/6873445370147866877", + "8/1239104672844137908", + "5/2702150542655057165", + "2/7902397444255639382", + "8/5655860448952116028", + "7/2712515312536935847", + "7/3126291427755749797", + "1/1392071743822559311", + "5/2761884953622735045", + "6/923553318313427556", + "3/8445919348955373063", + "6/5688537402056141556", + "1/4963797020105044721", + "3/6243985023460516353", + "9/3443694179947150619", + "9/2989897541428507269", + "7/116550350665526057", + "1/5382199802193836041", + "5/4434500587603106745", + "6/7448127469256435826", + "7/4444689493617863247", + "0/8626095268350027320", + "4/139527579193536874", + "2/341229219249260062", + "7/8116371410427167207", + "4/5413774941096058124", + "5/1927734155158373105", + "7/4178372727600568507", + "8/5275831286123016058", + "9/8275073087586271839", + "7/505244405396275357", + "3/1308282546170964273", + "8/8093757860840262918", + "9/1982997705763137989", + "3/129054093384319613", + "8/3549517283033646488", + "4/7138281883622192894", + "6/4793184027888702486", + "2/4813015553977729402", + "3/7558325973677653833", + "7/4655710601805040517", + "3/2774647057335343563", + "8/5261780403671340928", + 
"8/7096410625298216788", + "8/5161866810007250038", + "6/7789300072348545026", + "4/5525910920880612574", + "2/6014394015862468112", + "4/8278446515690313974", + "4/982117949095567334", + "1/709361810424054111", + "2/725391449293818402", + "1/433328974784645451", + "7/625965360274791607", + "2/5735204273530887482", + "4/5500890399258653754", + "3/1699634147282292793", + "0/5559749960132720480", + "2/1721069361132682462", + "6/2889438391774379426", + "7/9025343325554081337", + "4/1755734736856581944", + "5/2647076492697411045", + "7/758708411454677827", + "6/1980193658215847166", + "2/7767057188371998412", + "6/2957554937314433856", + "3/6822878015107633153", + "0/367929007626953940", + "2/6811288846411424432", + "7/1405605546747008137", + "1/6474197563450316941", + "5/5513852353362961635", + "1/6880165794060534311", + "3/1373152996540211353", + "7/4175309191379951557", + "5/1124366149893063665", + "1/4090272821872205561", + "1/2870565491315764191", + "6/3651232215180350066", + "3/6816780302182339833", + "5/1404928100171926815", + "5/3398158305625371585", + "6/1173659754092044506", + "6/1847107523763299766", + "4/623932937153560984", + "5/3810331789140828075", + "2/4994256342608708922", + "3/8572718001107500203", + "7/6020011643258343157", + "8/7138810283392218888", + "0/1178049153311943070", + "4/7220840217069940664", + "1/8380984962060950241", + "4/6610695460070356514", + "6/4593615099879103386", + "2/6318940155021448332", + "6/4877795253170227246", + "8/5625360725370123538", + "2/5245073057752601772", + "3/5032252693744548583", + "6/3007141350068721946", + "4/3495917509753266214", + "9/6804315006642362859", + "4/8000884384075299894", + "1/5562264197470237381", + "9/3333927373644809789", + "1/5729914957193412161", + "4/7508051830360446354", + "9/2549537235226625569", + "0/6624113543034085780", + "1/3617118845089874551", + "2/8323265279392390192", + "9/6920936771913448259", + "1/2332994901080222631", + "1/4517416736116217091", + "2/8325441093063724742", + "5/6148957102328479275", + "5/1737793948849299245", + "2/339720302394375012", + "3/8428027223154511153", + "2/255779912685977022", + "3/8541548153190708443", + "7/2019298367151599677", + "1/6738352374379108541", + "0/2820818233280300960", + "2/250197765428340922", + "3/3901765515833253253", + "5/3103813464864175995", + "7/524980998480594867", + "0/6678809849236299320", + "4/2625485762096367224", + "1/5923459805139931401", + "0/8092906610655828850", + "4/809815191355712414", + "6/4023461282002153916", + "9/1199186984216641909", + "2/7496781274290786892", + "7/8492495882014730167", + "3/8100842010897213253", + "3/8090874503940925283", + "5/1219903755543174995", + "0/5921897427008875840", + "0/8337261575460534540", + "4/7581212209526173754", + "9/7049359462515537049", + "9/6810466771697797779", + "6/8306085745644326776", + "1/1881144770288923041", + "2/1289432292170167862", + "7/2449982284311914347", + "9/4371526449269825979", + "3/8980643233572815243", + "3/7552749586360909333", + "6/6030508088316275376", + "5/7840800216345376565", + "6/7348407556282005666", + "7/2894701758886328357", + "0/5316154064835921940", + "6/2051467543140615396", + "6/2552817628858553206", + "8/2195134589017711488", + "5/8885794078974632045", + "0/6533467058439050670", + "1/1706796102516243771", + "9/5218246640138201709", + "8/7830687480118093038", + "7/1391410003824850567", + "3/8954956133197020983", + "8/7188944132175738748", + "1/5917874226658439761", + "0/7083065615240199830", + "7/2636359049616758927", + "5/1866740041768931335", + "1/4170417758903980411", + 
"9/7864349982754035179", + "9/3717954214079918659", + "2/2715458165974974592", + "6/2798265305047434756", + "6/3875557725906455436", + "4/1235063538098182424", + "3/3731001811847628653", + "1/7084653422743236531", + "7/5541204715443364337", + "7/5197183485665071267", + "1/7830062573508498821", + "7/8871319427327393627", + "8/7469182527331058718", + "1/111400007359129091", + "2/8042868129841603412", + "4/2889287675568576104", + "6/5971128111071205806", + "8/5326639845218384718", + "1/6306378982101442801", + "7/674322707158935867", + "8/8218011240909808288", + "3/3067628465261233723", + "1/5587717238340115401", + "1/1861092211859222111", + "6/5067661545710517816", + "6/7236380593327694156", + "6/4183515887346955776", + "1/4753460352745319831", + "0/4450128667582363020", + "3/6623301005850812233", + "3/8992470295804906303", + "0/4004436873857807090", + "1/5238584629256071341", + "7/3609430347953103917", + "3/4398947929552025793", + "2/2415058977317263272", + "5/6196678968096962775", + "9/681317204105963079", + "6/1867690807119275066", + "1/3177907144939452471", + "1/1724668430970159481", + "3/2898916308724968423", + "4/4794822077193158214", + "8/1127632850218300238", + "3/4531998551442934313", + "3/7132934392320859123", + "6/143632378543288186", + "0/529955899139820130", + "5/7678103379787062935", + "4/3809448106797971884", + "7/1421066311370106217", + "8/3274072592698772538", + "8/542909365421238", + "0/9143407152180245470", + "3/4419415376096835873", + "5/1332401009159433575", + "8/832524819209515798", + "6/132125511756580106", + "3/2362703425934229423", + "5/3217782787472964825", + "7/1240215154903398527", + "5/5077133732883027885", + "0/8450749031879929410", + "4/4492960262524923424", + "9/8440631568979156719", + "3/2417141131720310713", + "8/5753330215010967528", + "1/3901474519974299221", + "7/6095681405769516947", + "8/6014444902720608428", + "6/8298588078012757486", + "9/5045798674263790889", + "9/6850459233364177669", + "0/8373028569797048100", + "1/6337349721818756871", + "5/6690495103439379625", + "5/8989885039722904595", + "5/7584894832371194195", + "1/8987941979353575851", + "2/2935630532140297742", + "7/6440798160090520507", + "4/5306079829291212184", + "9/7077975619301782509", + "0/5451545064163194160", + "6/4951473314331317186", + "1/7048127217568365451", + "7/2282889993046535777", + "1/949839330949353661", + "5/1103702363339546915", + "8/6195851059638053398", + "1/1306414790136876321", + "4/6193651908993888824", + "4/3446396561333608514", + "5/8382842097767822465", + "7/8511767851524538647", + "9/555525349023348109", + "0/5628707354432612020", + "8/3493566071169604868", + "6/6970746157478525556", + "2/3212512770763774692", + "4/3558546714509011804", + "8/3893101435404719008", + "4/5781033577208828084", + "6/768471707416200296", + "6/8157490861927429696", + "3/8085702697736096033", + "6/1124968517867730606", + "5/1977855559648316855", + "1/7262196385906332731", + "2/6085639806438199632", + "4/4070086065653274524", + "6/6151597244604639476", + "6/7013561344015048956", + "4/975811635341609304", + "9/8097777766471554319", + "6/2552734213960573306", + "2/7751363221716228372", + "3/8655966512778755363", + "5/5800224141316231125", + "6/5826278938722677646", + "7/4296977833341991587", + "1/2732453869327955061", + "0/610829115799168820", + "3/780789087026681073", + "0/7988769748039650910", + "6/6420009153169536466", + "3/7245470456694992623", + "9/2348863327928769759", + "4/8242217938580222144", + "7/5072909369972520047", + "2/8012591174532230632", + "6/2784629076637749556", + 
"8/675644996326245348", + "2/1499601931380250192", + "0/1975849785064535830", + "2/749579824447787972", + "3/1881286053082663773", + "6/379091422413369016", + "4/5303062218074694754", + "1/3850941047512203011", + "1/2265855161266097681", + "2/7063627941009170902", + "5/923853920966009005", + "9/927385936080810469", + "9/8933947460987390069", + "4/7033233521012343784", + "8/3466560192773844048", + "5/2716996290218539835", + "8/119379071557588638", + "4/4177056044747669984", + "2/1045847649334335362", + "3/5522915515125385003", + "0/6517514470021882100", + "5/4312372285374525855", + "8/6048511885746898288", + "8/2315460568681060208", + "8/8119787702625733228", + "9/66371020032727869", + "9/6325728937715111879", + "1/1581629707882210201", + "7/3199361337142473477", + "6/2425080442981981276", + "1/1762150802117448871", + "6/1374591549707496566", + "3/2455807521054754063", + "1/5058729732387507341", + "0/13958053289047560", + "0/4251029968316135200", + "1/1677109879729212491", + "4/3283920330556289294", + "6/8145806657366088316", + "5/7437238158851589875", + "7/7457564416782577987", + "9/2107630002290268599", + "6/2056434294011058006", + "0/1126107596863674730", + "9/8644948746887826429", + "3/6703137316210223043", + "6/456384370934383216", + "7/3096321177420802227", + "8/8809923075487854008", + "5/5702636341266013815", + "8/1594818620592148958", + "8/5536031782276048148", + "9/3360228738855139479", + "5/6867883194346259435", + "4/1333105547455797864", + "6/1587676473706037786", + "4/7009759717195130374", + "3/7942393315372211883", + "6/6791850565924074886", + "6/3965644804505691686", + "2/4926989997144177842", + "3/5767001413510675963", + "3/682075445492381493", + "5/2872099866491130635", + "9/153301345427922179", + "7/995207997983654967", + "3/5451989996205050163", + "6/2162052262801946516", + "4/6734543798194022014", + "5/4838912104458469695", + "0/8211223810014835110", + "3/8568924610154735883", + "0/1430424448185754630", + "3/2396374806967758533", + "4/4432516223845087434", + "5/361567071320367255", + "8/5289017820741709168", + "3/3011133751710979243", + "4/8737815739488266554", + "6/5238349381662228876", + "3/7489833482569711853", + "6/7015676332224061786", + "8/830487562275099558", + "9/4748341444564259709", + "0/6466415645121865450", + "8/5892168971538546328", + "8/8488407272463654248", + "5/4385255082537335895", + "1/1886494620924570961", + "7/7564876263420695037", + "6/4408975108973295846", + "5/2973989482621088935", + "4/7114983652684321474", + "2/2824621868302274922", + "0/8054700857409232470", + "3/974657910986238303", + "4/3713336235540478954", + "4/1721790012939033454", + "7/5123264782288536667", + "1/123079904757791821", + "1/6513093663347850081", + "7/3503048188506196397", + "9/5624699639073578659", + "0/691988672774479560", + "5/4878139946750718545", + "1/5518261461494463001", + "9/7623815043517867279", + "1/7098276925021196981", + "1/772232194645578801", + "4/1290718212267618574", + "4/3336692514863502924", + "2/3553661623190236162", + "5/88515463061949435", + "8/2929297841296394448", + "7/6064127175473404837", + "5/7780076042514317315", + "5/6097863030661385845", + "2/5517514554288168232", + "5/2408037004103145425", + "2/6146653747929837122", + "6/9105071270237228836", + "0/3945990285777060300", + "1/243470594541917381", + "8/7317493949104611028", + "7/8607788841229897727", + "7/6538322174767732287", + "8/3401025573293395578", + "4/1527162385508417974", + "1/9086033697753591041", + "7/4678771896566913767", + "3/2134821284017688653", + "2/5581846018284472232", + 
"3/8824739006839391253", + "3/5883986458375682513", + "7/1246803187295569807", + "6/5014217536857243106", + "1/330383399685245251", + "6/6721305365660396636", + "5/6425270577418507705", + "6/2126971527351581906", + "9/1411284349273289819", + "3/3144024577774710973", + "6/4463774282279540756", + "6/5863655170687154916", + "4/2557325449913379764", + "1/111256495922937941", + "8/7520628969116844348", + "0/4756595039805149720", + "7/9179873404131259967", + "0/6712162018599777630", + "9/8070847950445761409", + "9/5761479399946854179", + "4/2671182197303150984", + "0/6268534403507405240", + "6/5381350544825283536", + "6/4456233028156184596", + "0/6914415337401828930", + "4/8023861737236545414", + "5/6210119206888473105", + "4/1600656686511433694", + "7/7924814509898662677", + "6/3859766820123022036", + "5/2488143266511751305", + "1/8344833107682387191", + "3/473850194802898723", + "3/8371193472106014183", + "6/628254286940395136", + "2/1218464266670370992", + "2/7733110530979185782", + "5/577712379864945865", + "2/3673424039407095892", + "8/236525911787854448", + "8/5555824237077513698", + "5/3738261462210702795", + "9/4458399137121242909", + "0/3223497478111195610", + "5/4815956493410374445", + "7/2454237346764965137", + "6/3401553006779611336", + "8/6462476098504909628", + "7/8653690004453280557", + "2/3609856943096327432", + "1/1968822047024548701", + "7/3269514717700892487", + "3/4574318620109004873", + "6/4888280137647879066", + "1/2755255047439220171", + "6/3143645857460810566", + "4/2936810763181299754", + "7/8280122398431780557", + "1/387059280783045661", + "0/3340203211501676540", + "9/8803218983614801239", + "2/5482227222870499762", + "8/2878304132989896368", + "8/627558541987835968", + "6/2864271951901824406", + "6/6822907835643867396", + "0/7568351018788172520", + "5/631741632881919775", + "0/4883178972869056480", + "8/7875541657205770278", + "2/8960374251596696662", + "7/6408635462675232677", + "0/2535028045620318960", + "0/9200623882536262490", + "5/5471154108753371845", + "8/8259921920047715868", + "1/6651709141255288181", + "7/8686567935077986837", + "1/3314509704403681651", + "3/2786822104538631453", + "3/7419276271996465953", + "3/7718536980439513073", + "5/8519806120600142805", + "6/729764543954149466", + "7/646085087063994097", + "1/2510518267862819311", + "9/8133032590481372559", + "5/2934024186370237335", + "7/5759500972715424337", + "2/8987002399062214942", + "6/5245950681462338846", + "7/2663774224675899447", + "2/897055308016873762", + "2/2751830752453072472", + "4/2990869328193205914", + "7/5456154778222003457", + "3/7578097056910117623", + "5/1506277472473474915", + "5/1893609424769848855", + "9/4792169208606208219", + "7/1180173737017484147", + "8/3505441035091145968", + "9/1050089539956784789", + "8/4215198346427232088", + "7/5259927902280274437", + "8/1817799696802302198", + "2/3458654217260499072", + "8/2951495124159360418", + "7/1221489356083788397", + "5/5984489796346927095", + "5/7923170820369317405", + "6/2092402915128447866", + "2/3814054640265963432", + "9/500346823083616129", + "9/5321809387447652109", + "1/14075469470787881", + "3/7025227066145448203", + "0/8074754476218727900", + "1/5704629544502855301", + "8/2851603535592584718", + "6/1734045398297044056", + "9/3387925455116118679", + "0/8559962400526410910", + "8/1544236669243609388", + "8/7450919783730919038", + "6/359754965149295976", + "3/5677064444015133503", + "1/1541365690629957721", + "9/5223859367676344409", + "0/8248807204760783970", + "7/518965495849648047", + "2/7639970788620392152", + 
"1/2895054497496824431", + "5/3001864397933680405", + "8/4412828778005392728", + "9/2848716591616520619", + "2/988155513756298762", + "6/816043156630045946", + "9/7211234512052031229", + "6/6211095764422870596", + "3/7670077723001265143", + "3/3373110731915244463", + "9/1593076713359023369", + "6/7910871766647202766", + "2/5896514182903533642", + "0/4206927045461146330", + "2/4809775450270578102", + "2/4884461931699427272", + "9/2996530417954448059", + "8/2211113649547618348", + "2/3331890870838318602", + "1/8678744559790903201", + "8/7318245119328234478", + "8/2170222870989730858", + "6/3660955677983133076", + "3/1948622210396124163", + "7/7019785706168847987", + "2/8318652809636645252", + "2/190949268332181832", + "2/4060807068226685052", + "4/1952850522372485444", + "5/1497595267382616485", + "9/4491238733320502279", + "8/1877631225106076958", + "0/2312717721228006740", + "0/4535177541660557750", + "0/5236141067540503850", + "2/4438122947900743932", + "4/1536340188331302614", + "6/4181209914436403756", + "0/1624565670482030990", + "3/7188781915339190373", + "7/4291068726621898897", + "9/5284383697341929739", + "3/6252496108574304943", + "5/8900621394037107425", + "6/29850514377114056", + "3/7336836280153714713", + "9/458261014440521709", + "2/3320659143346840522", + "6/2018401352224058176", + "7/8440511112304395517", + "6/7944368636016665206", + "4/5989980119573415764", + "9/8671589176560716709", + "2/1121802978946440822", + "8/9157273105197170208", + "3/2593618598192817563", + "4/1992381929919490554", + "3/7909799095527677303", + "8/8748990210559213098", + "6/2130164126221604586", + "1/583721583512164771", + "2/6994780871250102312", + "0/745353349180485440", + "4/3612248732219924804", + "5/273835879855404065", + "0/1816261651024674810", + "7/8653107013847210027", + "8/6466286615804333078", + "4/1652348751965530264", + "9/6635087989826670439", + "7/5210903041834772657", + "2/7378386104625448952", + "7/1065864042260651097", + "3/1893178744967546483", + "6/5140045229454791146", + "0/5095386043245893710", + "7/7533479196955520527", + "5/8618880256706554495", + "3/8590698985623726633", + "3/8838433116730743043", + "9/8174720025133281489", + "2/2691839397189024092", + "6/5887812645412471266", + "4/1424273517731910834", + "0/8174355422039456710", + "7/8012048608489936077", + "3/872247437425221173", + "9/2301176534478093279", + "9/4589868665188562869", + "3/5612184339400495333", + "9/3845270106757521199", + "3/5523200696352135293", + "1/4681695433163870341", + "1/717624175651442541", + "0/5310610516056872460", + "1/7069372682320558701", + "0/1011462132209711740", + "6/5921247169041448766", + "9/7328703294685979049", + "5/8872344703372007545", + "0/2316741519752623840", + "6/2263894017437142966", + "1/749825744550307191", + "6/5850348682071581526", + "5/5715475686357816085", + "7/1206918545507786617", + "8/4685982476860194608", + "7/8751046373180481487", + "4/1898765554289435874", + "9/6477827643818359699", + "9/5237591301199042129", + "4/2408855355088563454", + "2/4143039283540666892", + "7/5795396259748844387", + "4/2078438344509092064", + "1/1509027799117919611", + "2/3361738975096649902", + "2/5105970442481665412", + "2/6262833680036717772", + "3/7597316024053104803", + "0/7786129994382148270", + "7/4899165504429926447", + "3/3782049197069052043", + "5/8020494623608521015", + "6/4141100167873444696", + "3/4038728624667921633", + "8/6524583540124634678", + "4/8587813319965183684", + "3/1961107338192758323", + "1/4837623180596207231", + "7/8018390819196050947", + "4/4598908326120234754", + 
"2/343712345541196082", + "2/4031682617126197662", + "3/5678946914587284973", + "0/872810318700275800", + "0/5662048210370549490", + "4/2113746151335174254", + "5/7233947125794727195", + "1/9035850342572073941", + "5/6177117516195607825", + "0/2577209497050625740", + "0/5391573223707446900", + "2/1425570447847942332", + "7/2522605662544500287", + "3/2942697230019206743", + "5/3041893866003804805", + "5/8435105018700001315", + "0/9154995939919586240", + "5/1292185307802124065", + "0/7040049099142531610", + "4/6137540671867540984", + "5/3740757761612932885", + "3/4224984605186758403", + "5/4587640472934137525", + "7/401228575530480567", + "4/2221521638726606314", + "0/8980955969341634350", + "8/2333379012598892558", + "3/6417018655147166613", + "3/668736714293084943", + "4/6575628826252947034", + "4/8939233495397566104", + "8/8752842349033563618", + "6/6729344198314507186", + "9/4589053409498800829", + "9/4818728989187517539", + "0/5865787462267702290", + "3/5296422470883699013", + "7/7093999111073352637", + "0/4972355219662424160", + "1/1521525189450511911", + "0/5525969863157062720", + "0/4407176300058130190", + "2/7932321634834033872", + "8/1753584995044876578", + "8/3398729607027236238", + "6/8270191711247718006", + "7/464371662989350047", + "1/502669310103232921", + "4/2358544044324210674", + "7/5427204764500978847", + "6/8895541135524891466", + "6/4979034230557868646", + "3/2871505395744595603", + "9/3446775220762039299", + "8/5803985753867881018", + "7/2926517149182447467", + "3/5443482058184918283", + "7/3662220453068700207", + "8/855003243676561208", + "3/6650527556712225633", + "5/694552105644667655", + "0/8670411436802594950", + "2/5903891121090469772", + "2/3858648399710246282", + "0/2900995232440866870", + "1/7578384525362077071", + "4/1330880949184825654", + "5/1873873604942283725", + "4/355196322744063784", + "2/6066991611797097412", + "5/8933323532541420315", + "7/7703510059431051957", + "5/9131647213933311395", + "5/3039385667644127805", + "2/8810996325338888012", + "3/8369726693628958723", + "0/2636876631900746000", + "2/8240330771143298962", + "6/2538227379619963996", + "8/120313242285823818", + "8/8312333486248928608", + "0/1661755431939727730", + "8/190889763016197518", + "2/7441894688341822402", + "8/5816697722727794828", + "7/7986222275949718497", + "2/5651571978018196192", + "7/2213854042391524827", + "2/6729177152338574322", + "5/307262761595652215", + "4/3533051141446205734", + "8/3823016094022735628", + "1/6704035059124956161", + "9/7476764964031990299", + "8/8972437245679988878", + "7/8330857133744603857", + "7/5092090866899578807", + "2/9195839135192014792", + "1/7008220911332030471", + "4/5235059319307947094", + "6/8484247471823743616", + "5/4526353411592822145", + "8/2108794643520371228", + "6/7608090862555160206", + "8/4553190892628609908", + "1/4983562171924596171", + "5/6722825344642673975", + "6/8640124944368483276", + "2/8712411652088153532", + "8/5310203146636038308", + "2/35114461466894952", + "7/3324993432000333367", + "2/8830409866324801212", + "0/6688495297820529400", + "2/7660005044867218612", + "2/8965889610121040232", + "0/5616765306113513760", + "8/3601044358235163068", + "5/8982342833854980155", + "1/3635993485851908911", + "0/8892256834371837940", + "8/8302517591012197968", + "1/3395264497803964301", + "1/8759818440631828231", + "0/2844368355898131150", + "4/2775354675143660364", + "3/7447514522793883003", + "2/8075916088582651752", + "8/5413551715422859468", + "3/5970290422990604833", + "7/6655797521296150897", + "4/5340150727178111704", + 
"5/7718379222613271855", + "1/2640643288392457331", + "2/6697946127959962002", + "7/7531846030051823727", + "6/1905951965074915616", + "4/9059013157626347134", + "2/8572307410535889942", + "5/2181674704836382945", + "0/4131979252775564660", + "3/972833217743589653", + "4/2539188405944407944", + "1/3201683093072411401", + "2/1004724000688170552", + "1/5791845589992018391", + "4/3104569774069019544", + "7/5653438432516478377", + "6/8314277176434555216", + "1/3250960644595960561", + "2/5055337112494986792", + "8/5344554772259361138", + "0/827637235619696700", + "2/4619398241055342002", + "9/6720734663238538289", + "2/7628832110586514272", + "0/8912770287944959640", + "8/2102286902113079038", + "4/8575320871477379914", + "1/1598642615902728091", + "6/6245114840613858736", + "9/6439756559392882549", + "6/5853143164428267926", + "7/4490111009466934127", + "9/6486240360055185349", + "0/1714179919639209140", + "9/895087578703445419", + "8/850247708893797718", + "0/6926861501104038640", + "3/3126697082837770783", + "8/6414014966671914968", + "5/3467431173460049595", + "6/671333961983508276", + "8/5518816348824288608", + "0/5691478655558471940", + "9/7320406662074662319", + "5/4831227871500338855", + "9/7987188484910818339", + "4/2077631988930213594", + "4/3029172804364718444", + "4/5044815003961393834", + "9/1450911256023233609", + "4/9101749090883234364", + "2/8589022505510040492", + "5/6420883495034926815", + "6/4467181045067221266", + "8/8018764868387221158", + "3/8711433518143634783", + "0/8914373840576530460", + "6/8371627752127414526", + "3/3941710224166093243", + "5/4446065416973979525", + "7/6308382789939652567", + "1/7211195420814974551", + "2/3241698566682350962", + "4/3071420849808952634", + "7/1235290824774309227", + "0/2394376384642802610", + "4/5981427040322860314", + "0/8525852089524367540", + "9/527866204374499949", + "2/8476317259911823332", + "7/4295975086255889857", + "8/8198966742577782558", + "4/8237335534011237034", + "2/2036593766115331822", + "1/8792773110957859071", + "0/382192830845740", + "0/4271594026154238810", + "2/7521248020904051532", + "1/699851464459839621", + "3/7155063617421757953", + "6/7631234123537685676", + "2/48806118069599122", + "1/8521350842968645031", + "9/1743535594735837619", + "0/4359656527118517790", + "2/3560518805923687412", + "3/6534950024089237653", + "8/2933918162371857328", + "1/6934274209926083441", + "0/5479804697298437040", + "8/8166567780875508858", + "8/1114131680001999378", + "5/3519100519525308685", + "5/3136056352154258175", + "1/6779152075472735761", + "3/7262864065882180183", + "4/3498186111862863494", + "0/7939209027710417260", + "5/6871413180860072695", + "4/4197586603012252094", + "7/4746634862485321777", + "2/7682763949440213862", + "6/4639555088396105806", + "3/6719017312476039173", + "6/5670436731781298816", + "4/1280051458878161154", + "8/6407860248699819028", + "8/9222040047249844458", + "2/2258356015996424982", + "6/6053199460550253196", + "6/663371804375479746", + "6/7414362022807735886", + "9/5519688840474660369", + "8/7934087128032287688", + "3/1897304386289042073", + "5/1161741402650279815", + "2/2846077957341649162", + "2/9007290405290983542", + "1/6756998997722282931", + "2/1176572081035769202", + "9/8631632138367041389", + "7/8883013410742572597", + "2/4266902425398243022", + "6/5485007422699501026", + "4/7829108274631729144", + "1/4015384826008232441", + "2/6477322036975819022", + "9/8904427671570886709", + "3/2818779714896100323", + "5/2925805146826157965", + "6/7905626046821248566", + "3/3209668789788631363", + 
"6/6881348944324437566", + "0/2705757124791257540", + "4/6596804483322828644", + "9/1866656240742923089", + "4/1711406166813621424", + "8/4839096533544296748", + "4/5187745540396342534", + "6/9216743680104853656", + "8/5201878381238309928", + "7/8202757797117286827", + "6/2174478955555075206", + "9/6148125901806235509", + "3/9111435395978948793", + "3/7877041305633432073", + "8/7884978284759726758", + "1/3467918854858782411", + "3/2217185802136710963", + "1/4327188527532008681", + "6/5258227976100010836", + "2/2269139573228315182", + "1/5591804857460095161", + "7/1445996615629653457", + "2/2601293953407105182", + "4/6887843291475895664", + "6/2460204596706963286", + "4/444304497790424864", + "8/5220589817387066288", + "3/1253541007794808503", + "1/3911910628295666091", + "4/59716680249630594", + "9/2093118278506482579", + "1/1158926960827141201", + "6/1797018019315524786", + "9/7754430571578289629", + "4/3423417731994324804", + "0/1290125494127258440", + "9/2404479006277926819", + "7/7629957593116436887", + "2/7073222024036738422", + "3/1414164488266803543", + "1/5801742470731129141", + "4/3979539599623415844", + "5/882143123730287665", + "1/1428458239061290271", + "0/676994000435173080", + "4/574787773440459534", + "1/5906491613125559131", + "5/6401029053916578925", + "3/6159642200931523803", + "1/8086389804674076911", + "8/8575597864921986288", + "3/5759390714776405273", + "5/2512285756169640385", + "5/5868191702965170715", + "5/6170331810480386005", + "7/7620853480649520257", + "4/1818332309386975354", + "4/7146542349954932994", + "9/4857160690369466719", + "9/7014553304196052439", + "4/650430037313030894", + "4/1899198978065935274", + "5/2552542350376939735", + "6/1278253081515443036", + "5/7195753311039298905", + "9/2906807740865970119", + "1/4461968112107671601", + "7/6538827515097218847", + "2/8637837000876959042", + "5/5052527902966011205", + "6/4201886104088018596", + "1/2561401897376971281", + "2/5175281955406061442", + "7/2065496480275303547", + "1/6277928329391151231", + "7/2375392543677628457", + "1/3265476741258445561", + "9/8090197319216310009", + "4/1930087082312443974", + "8/1283592059629229838", + "5/4765104382877388825", + "0/5331156255299777950", + "5/6885793821433191635", + "2/4739489777538469062", + "6/6633150904522683886", + "0/467443238919536990", + "7/5452213415901008767", + "6/157443985502498626", + "1/706035202215802381", + "0/8123663997528475060", + "0/8359515295614847330", + "0/3126257232839355290", + "4/4153750369797952064", + "6/7724797015074964736", + "7/879071612373332487", + "9/3908801549964108189", + "8/252000845733733818", + "1/7516614937743018801", + "8/5178317349502190568", + "4/7749153570476289064", + "8/868061732491369808", + "6/3184551087108569166", + "1/3496068832906934211", + "5/8423470882970990335", + "6/2386810590296817916", + "4/7965677467899022054", + "5/8904822319449499935", + "0/1614038526732100950", + "0/5562768832970375110", + "4/3303474614372738664", + "5/8711514166689472395", + "7/6911400410302718657", + "5/6148590382905320905", + "9/8369214750829593899", + "2/3825621828126009472", + "8/2745352611846923188", + "5/3968756301597113825", + "4/2031197086742672044", + "6/702363754714188626", + "7/2234249235048399867", + "1/6588076113746885561", + "7/3930464084186264507", + "0/7576885519182017190", + "5/4860533307021547215", + "8/8830476409126330768", + "2/2357055483050663332", + "4/7686469476165979294", + "9/4456218515052334319", + "3/4924928088487924483", + "5/1539827617121628825", + "8/3075410605787710128", + "2/6965310837859584912", + 
"5/127085980450427915", + "7/3645329034326328807", + "7/4985948676210147337", + "5/5749099325640620125", + "6/5204524307381688636", + "6/8468278637066688626", + "1/1482720389714333081", + "9/5648520030280493849", + "8/7239629677590419118", + "0/7760484713936417890", + "2/5494328499297897752", + "1/8478428305558137471", + "4/2144101802400822884", + "7/4852784191315959567", + "7/8527277705306299577", + "9/8331668506194461909", + "0/135248480526530010", + "4/8280951483532666554", + "9/3957091627414769649", + "8/2822337248695097128", + "5/5238871433779353845", + "0/7667070359968417650", + "8/7062359733772252538", + "7/6545153194676159177", + "4/2499836152988221874", + "4/4170746264075688304", + "1/1810958824167115191", + "5/4236314445855462045", + "7/6877813767929165837", + "1/4154337379679278501", + "9/4243457780979651699", + "5/8768873897543523545", + "9/1624413875839220119", + "0/5312763549025513670", + "1/4865655218439839231", + "6/7683632605710551566", + "0/8106655378383995480", + "5/3270739459103570125", + "2/8802986813103864402", + "7/2440360783262562887", + "4/8406534285417735194", + "0/4353102586922240240", + "6/842991186332112526", + "3/1461595573867758583", + "5/5873115647796296075", + "9/863637425013332369", + "5/3292948933284324555", + "9/6147700006904588769", + "6/8321160453936972866", + "2/8300865290622089992", + "6/8833666608028044106", + "8/3277846831306320158", + "6/8916036895732522496", + "8/5432205153406596908", + "0/951522184347004450", + "7/5120679132118417367", + "9/8897755575083075089", + "1/2487879609577096601", + "9/552554810110800499", + "5/1022113573246895015", + "0/2186363542906764370", + "9/5439721440003627499", + "8/6827508592796473878", + "3/853800503763044943", + "9/5897381166662672479", + "7/8951858935169278157", + "9/9029243399810617739", + "1/7844822918842053261", + "3/5487317706806274873", + "9/4439526881863295879", + "4/539720653783917974", + "1/6661917859213066551", + "2/2421036979292221272", + "8/4983123276733203258", + "7/3623614682534708437", + "9/6987810233258579269", + "4/6188804642239067724", + "4/251174881638860544", + "9/6991632320475546689", + "2/1408289917605119482", + "2/1100514084783754682", + "4/145245468037886074", + "3/1195023111293678063", + "5/4021455727242353185", + "8/217741750323602538", + "1/4637085770032801101", + "0/4085271909750238500", + "1/8064076133108271241", + "7/5648193759127886177", + "9/7556301115387688859", + "6/5022456475789425506", + "2/3681590364037504322", + "0/4971470669768233130", + "5/8523869388662426115", + "3/7713398501964457073", + "9/4152015275158923089", + "0/3943343976553677060", + "5/3683048663924206105", + "6/1572167581430474256", + "0/7387331977119620580", + "0/2442861123545071640", + "8/1842684441126998408", + "9/1322609134373595589", + "0/3327887495166663300", + "9/6535375213828833109", + "2/7431245240076187962", + "7/7764881265361360997", + "3/106690291880000073", + "8/7213026045466553548", + "0/8736918077908108260", + "3/8689321620751842943", + "9/2711191801393106699", + "1/8258114604697888361", + "1/2278694938070140491", + "1/8137590830762050421", + "2/1157175874401749732", + "2/2393905683265105622", + "1/914833984566867741", + "5/5470868599967276255", + "2/8269245622600219022", + "7/7808015660955673267", + "2/1723660646870529752", + "3/6752165843335969983", + "9/100772444504484719", + "4/2988821767380899174", + "9/7264807316116657779", + "5/905253681647524045", + "3/4264262483741500763", + "7/1551202276712815937", + "3/1712485084618323603", + "8/4264949603460030128", + "8/1294592543541156028", + 
"5/6400918377415310545", + "2/3671051426768008112", + "6/3423359454743746466", + "5/1310833939031690725", + "8/8462508591434067278", + "5/7537380436058238955", + "8/8205376759665737728", + "8/1412808436654720078", + "7/6997764185034744557", + "9/238086143360985989", + "3/1727101356315380553", + "9/3966065225738309379", + "1/1492946431751459391", + "0/5686563297381816910", + "7/1341820125835939497", + "0/5826292114939950750", + "5/3644850763406327365", + "7/8870166577141855027", + "0/7181079459187717200", + "9/449745917094820169", + "3/8234740753935892103", + "1/7835015323122119881", + "9/6129375334023466699", + "3/8260648571591055253", + "3/5073310188870600023", + "8/5836345566678905648", + "6/8037708479870036", + "0/4400078146530134560", + "5/1466222815055317245", + "7/2647675969224624117", + "2/1921350562317843912", + "8/2444157513457650308", + "7/6042493500856483207", + "4/6341932016479485734", + "1/4102563741824677311", + "7/5044072376324951557", + "6/6801945299381036406", + "3/4241368665895328383", + "9/4162047082208818189", + "7/1383231684650162397", + "9/4594966003266577669", + "6/1429244327569781186", + "2/3467674422266496512", + "9/2285523391602512109", + "2/6093224344820427462", + "4/2031888507042504194", + "9/5247484245686797759", + "0/7782234173617220650", + "3/972625185981469383", + "7/6646976596368328907", + "3/3606007426502286063", + "1/4782699390497472121", + "8/7447295008554432268", + "0/2382367254484807440", + "9/6071786493058479989", + "5/2226496834504349285", + "2/1922287894164474382", + "2/3511652536354576532", + "0/7751056344643635500", + "1/2259603033209151421", + "9/1034519962846643389", + "6/3183681741837344496", + "8/5350083556248108658", + "2/2392275067362861232", + "9/373989062762737559", + "7/3188097870371432997", + "2/2247980248774388932", + "6/882780958676786236", + "3/4668510707019994993", + "1/2367364438333973441", + "0/8402085213686651480", + "1/2870521893641731801", + "7/5724962268081929307", + "0/304713568393608940", + "6/7933157126057351526", + "2/2036734521251179672", + "2/2015534260396734612", + "5/2653408946674684855", + "6/3181907783806246696", + "9/1451571769874623559", + "7/6650140790474745777", + "7/1622664686779566687", + "0/3248101090915254240", + "6/1528436150962474626", + "1/595080673197488621", + "0/3635159105848535000", + "3/2292328259375526973", + "1/1468244114898007481", + "6/6240990853318550746", + "1/8563448397722366971", + "0/4560222660008055160", + "3/2194503284663399103", + "8/1587744841294968578", + "0/3599551257327653500", + "6/661743518194441226", + "4/1129143441810890984", + "5/8946411716667957295", + "5/8014786438046707525", + "2/7190355700736806062", + "7/2528919749727969817", + "4/4867408899931044784", + "5/1735655375616201575", + "7/4886477168100164217", + "2/2793830142902597022", + "2/9211090345999964852", + "8/1998680687849490988", + "8/4940501149097505358", + "7/5914793817468412587", + "8/5422409625223695338", + "6/5015762139187637156", + "5/6794697067286320305", + "1/8042463532360018221", + "5/7296430777833940625", + "9/8417031381164827169", + "4/4900860499088253494", + "8/6900438556028158278", + "4/1565888348476572434", + "1/9036784514919469751", + "8/4719547356652238848", + "5/4086959822428721955", + "4/8631796898569087124", + "7/7232421868989453537", + "1/8957211329393964731", + "0/3676199748666491220", + "1/4375059909986428201", + "7/7545073373552772837", + "0/2465777636628963110", + "5/2185787635578698415", + "7/2520773292955632657", + "3/2759235138128174463", + "5/2086880656096660245", + "9/8770223496016217479", + 
"1/4592530958900628591", + "4/5260489449739864494", + "8/4821941168720689408", + "3/6444673042413325193", + "9/1142417114822897289", + "8/8769271304048916168", + "7/2927854883065529947", + "6/1319650605665504536", + "8/2644494331012803598", + "7/7343448002219482027", + "2/1659109981496847752", + "4/4274822910005691554", + "9/5193653758854258669", + "3/7187335788543296513", + "9/7083283826785949129", + "1/8156587513611022491", + "7/936785450805188957", + "8/581473255569205888", + "8/8359905350987680118", + "0/6824361808315582980", + "0/2365360498028818990", + "4/4679564420883816974", + "2/9052721950503496162", + "5/373889682449698105", + "2/687458591441253982", + "4/6873654131695864994", + "4/7051360079525549264", + "5/5114324981618208775", + "8/632938901957171468", + "8/4483823026297426048", + "0/1992198895632691780", + "4/5364415615309207304", + "8/5028944994966102398", + "5/3528456008072524165", + "6/7112118562958864186", + "8/118252258353190348", + "8/2275450690631514988", + "4/1946978663155343514", + "8/6233681473690215658", + "9/6279806586835225569", + "4/2879373078954726974", + "8/1549013336660939818", + "3/7214095406957008373", + "6/5393339504617187656", + "4/2517881929180422564", + "1/4852545107023879861", + "1/2724568998387593391", + "8/6084849172640607768", + "8/7270008942699250618", + "6/1071140363648738466", + "5/2735143308259190635", + "0/6300927560245599240", + "8/7518193587766901888", + "9/7647412694344407439", + "0/4562948372310139220", + "8/6552173265712792698", + "9/4183891267246019379", + "3/8974086741810969903", + "4/8110779481826086554", + "0/5774585348602500240", + "8/9185432023852007248", + "9/8425069436805932249", + "3/4814780557630045173", + "3/6939461388878140773", + "4/252775699533043464", + "1/6716260305285319481", + "1/5173490241528401041", + "9/3352289122345949059", + "8/8707631236541420788", + "4/74918524375105304", + "0/1615143097848458660", + "1/8813234378625256321", + "7/8041270930705172347", + "1/3487757670592986021", + "2/8456737402245879062", + "8/399252496511512748", + "4/1555213314542952784", + "9/3645469057361149369", + "2/7635621231901626812", + "4/2521127718669456024", + "0/5518221047860740040", + "3/6681379410142047483", + "6/7289241901887757306", + "3/4064206682021827313", + "6/3639274277177286586", + "4/2118606892634287314", + "8/5723376961365526028", + "2/4564470594450545832", + "7/9220837466571269117", + "5/4979706053165616975", + "1/4942602512188189781", + "2/5716547936123925682", + "2/3423405179377533632", + "3/1415836257264249293", + "6/5972786736005481436", + "0/3571338828365994540", + "6/7312849313598671226", + "1/2843937336606898491", + "6/474699454290027366", + "6/6330292556001246036", + "4/909442568912635284", + "4/7680997291201013464", + "5/4147077747163398035", + "5/6151324909537757085", + "1/6796073987499682721", + "0/6898977638359350870", + "8/3219113269792995798", + "7/8266858466495034367", + "1/2046757438416173201", + "8/6199426965220789518", + "4/135576677871906674", + "7/5615264629205870367", + "9/1266112543186503199", + "1/2290430044898376691", + "7/2473820833262195257", + "4/4902025119334400174", + "5/7252885561263514475", + "4/1972762746623246264", + "1/3883659452344819891", + "7/4594061062940913797", + "3/6804008292710344233", + "2/4371351687457744782", + "7/8626298746108887537", + "5/6811424934209101235", + "0/4091246605850584600", + "7/8952137357510751687", + "5/8030113276384292515", + "0/6468720473731626110", + "0/1709827542292688030", + "1/2348393436179937711", + "2/529119099527257202", + "0/7133113854996511690", + 
"4/2524658459820004394", + "1/3667237030693236591", + "9/1589294530693123519", + "0/4274383312622040130", + "2/1163537186953528212", + "2/6096448856486430472", + "8/2033110649345289868", + "3/4621559582526883", + "3/3757208051196364423", + "7/1204667184986944787", + "0/6360583494772778210", + "1/3569713072475604671", + "2/7563325215904331272", + "6/7903663697734943116", + "0/525724955230631450", + "7/430441809209082177", + "7/6017123990187354747", + "4/395274037685711754", + "5/4977684120300672415", + "3/8027587767931831963", + "4/8900094584109782174", + "5/2218303420591842365", + "3/784391704418059753", + "3/251699416383259213", + "5/6561797027020455735", + "0/5507674698517668830", + "2/8457888881356196992", + "9/5543853391837420189", + "2/6887509072600755632", + "4/7192785171423243974", + "3/8714102514041342553", + "8/5858516809446292878", + "6/6884469538299232386", + "4/1319917548969715264", + "8/2918419924257979458", + "7/2476500663606532567", + "5/2390914022658142075", + "3/1805343257334652033", + "3/3651659624787259053", + "1/9174046812112380561", + "3/1026000697424124743", + "1/6062369180416756561", + "7/5588700957299342557", + "7/6171623746897002507", + "3/6760688721988911433", + "6/1797166613779120146", + "6/7638391236373974596", + "7/8248643795635847377", + "9/6286836393172650559", + "1/5629500555325009731", + "1/8775951377860098931", + "2/3604199546257090342", + "2/7882105284498008202", + "6/6770333854471087216", + "3/3387115887010667203", + "1/1869531495279110281", + "9/5360168191648460129", + "9/7067380158973065589", + "3/753530946426922573", + "9/6773931714108134209", + "1/4620635304498900011", + "5/7880543326193537785", + "0/7371360686661874960", + "6/7695640292889283736", + "5/1532706046532373415", + "8/4180330058868420438", + "6/7286177474729595686", + "2/796665065994057342", + "6/5616816786917253856", + "8/1470080093955377338", + "9/7750581977249520769", + "8/780980038757016298", + "1/3565240653778188161", + "9/8293254174058128909", + "1/5911004139119139731", + "7/4526534244162930567", + "8/8290854360232801338", + "5/6438306373201434965", + "2/6256513592585992142", + "7/2523197298495909397", + "2/8299653299512327592", + "9/550155573094805859", + "9/6396008363000609299", + "8/1875907454040571718", + "8/4834513592876705708", + "0/4718565474006856330", + "3/2895881574226811213", + "6/1282083184332552466" +] \ No newline at end of file diff --git a/infer_api.py b/infer_api.py new file mode 100644 index 0000000000000000000000000000000000000000..aab1a4a49201c0cb3ae8dc3a908c84c284e0b40b --- /dev/null +++ b/infer_api.py @@ -0,0 +1,881 @@ +from PIL import Image +import glob + +import io +import argparse +import inspect +import os +import random +import tempfile +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf +import numpy as np + +import torch + +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.utils import check_min_version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection +from torchvision import transforms + +from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel +from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel +from canonicalize.pipeline_canonicalize import CanonicalizationPipeline +from einops import rearrange +from torchvision.utils import save_image +import json +import cv2 + +import onnxruntime as rt +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub import list_repo_files +from 
rm_anime_bg.cli import get_mask, SCALE + +import argparse +import os +import cv2 +import glob +import numpy as np +import matplotlib.pyplot as plt +from typing import Dict, Optional, List +from omegaconf import OmegaConf, DictConfig +from PIL import Image +from pathlib import Path +from dataclasses import dataclass +from typing import Dict +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torchvision.utils import make_grid, save_image +from accelerate.utils import set_seed +from tqdm.auto import tqdm +from einops import rearrange, repeat +from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline + +import os +import imageio +import numpy as np +import torch +import cv2 +import glob +import matplotlib.pyplot as plt +from PIL import Image +from torchvision.transforms import v2 +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from tqdm import tqdm + +from slrm.utils.train_util import instantiate_from_config +from slrm.utils.camera_util import ( + FOV_to_intrinsics, + get_circular_camera_poses, +) +from slrm.utils.mesh_util import save_obj, save_glb +from slrm.utils.infer_util import images_to_video + +import cv2 +import numpy as np +import os +import trimesh +import argparse +import torch +import scipy +from PIL import Image + +from refine.mesh_refine import geo_refine +from refine.func import make_star_cameras_orthographic +from refine.render import NormalsRenderer, calc_vertex_normals + +import pytorch3d +from pytorch3d.structures import Meshes +from sklearn.neighbors import KDTree + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + +check_min_version("0.24.0") +weight_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +class BkgRemover: + def __init__(self, force_cpu: Optional[bool] = True): + session_infer_path = hf_hub_download( + repo_id="skytnt/anime-seg", filename="isnetis.onnx", + ) + providers: list[str] = ["CPUExecutionProvider"] + if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers(): + providers = ["CUDAExecutionProvider"] + + self.session_infer = rt.InferenceSession( + session_infer_path, providers=providers, + ) + + def remove_background( + self, + img: np.ndarray, + alpha_min: float, + alpha_max: float, + ) -> list: + img = np.array(img) + mask = get_mask(self.session_infer, img) + mask[mask < alpha_min] = 0.0 + mask[mask > alpha_max] = 1.0 + img_after = (mask * img).astype(np.uint8) + mask = (mask * SCALE).astype(np.uint8) + img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8) + return Image.fromarray(img_after) + + +def process_image(image, totensor, width, height): + assert image.mode == "RGBA" + + # Find non-transparent pixels + non_transparent = np.nonzero(np.array(image)[..., 3]) + min_x, max_x = non_transparent[1].min(), non_transparent[1].max() + min_y, max_y = non_transparent[0].min(), non_transparent[0].max() + image = image.crop((min_x, min_y, max_x, max_y)) + + # paste to center + max_dim = max(image.width, image.height) + max_height = int(max_dim * 1.2) + max_width = int(max_dim / (height/width) * 1.2) + new_image = Image.new("RGBA", (max_width, 
max_height)) + left = (max_width - image.width) // 2 + top = (max_height - image.height) // 2 + new_image.paste(image, (left, top)) + + image = new_image.resize((width, height), resample=Image.BICUBIC) + image = np.array(image) + image = image.astype(np.float32) / 255. + assert image.shape[-1] == 4 # RGBA + alpha = image[..., 3:4] + bg_color = np.array([1., 1., 1.], dtype=np.float32) + image = image[..., :3] * alpha + bg_color * (1 - alpha) + return totensor(image) + + +@torch.no_grad() +def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, + text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type, + use_noise=True, noise_d=256, crop=False, seed=100, timestep=20): + set_seed(seed) + + totensor = transforms.ToTensor() + + prompts = "high quality, best quality" + prompt_ids = tokenizer( + prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, + return_tensors="pt" + ).input_ids[0] + + # (B*Nv, 3, H, W) + B = 1 + if input_image.mode != "RGBA": + # remove background + input_image = bkg_remover.remove_background(input_image, 0.1, 0.9) + imgs_in = process_image(input_image, totensor, val_width, val_height) + imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W") + + with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype): + imgs_in = imgs_in.to(device=device) + # B*Nv images + out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator, + num_inference_steps=timestep, prompt_ids=prompt_ids, + height=val_height, width=val_width, unet_condition_type=unet_condition_type, + use_noise=use_noise, **validation,) + out = rearrange(out, "B C f H W -> (B f) C H W", f=1) + + img_buf = io.BytesIO() + save_image(out[0], img_buf, format='PNG') + img_buf.seek(0) + img = Image.open(img_buf) + + torch.cuda.empty_cache() + return img + + +######### Multi View Part ############# +weight_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def tensor_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + pretrained_unet_path:Optional[str] + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + save_mode: str + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_grid_nrow: int + camera_embedding_lr_mult: float + + num_views: int + camera_embedding_type: str + + pred_type: str + regress_elevation: bool + enable_xformers_memory_efficient_attention: bool + + cond_on_normals: bool + cond_on_colors: bool + + regress_elevation: bool + regress_focal_length: bool + + + +def convert_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + +def save_image(tensor): + ndarr = convert_to_numpy(tensor) + return save_image_numpy(ndarr) + +def save_image_numpy(ndarr): + im = Image.fromarray(ndarr) + # pad to square + if im.size[0] != im.size[1]: + size = max(im.size) + new_im = Image.new("RGB", (size, size)) + # set to white + new_im.paste((255, 255, 255), (0, 0, size, size)) + new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2)) + im = new_im + # resize to 1024x1024 + im = 
im.resize((1024, 1024), Image.LANCZOS) + return im + +def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3): + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed) + + images_cond = [] + results = {} + + torch.cuda.empty_cache() + images_cond.append(data['image_cond_rgb'][:, 0].cuda()) + imgs_in = torch.cat([data['image_cond_rgb']]*2, dim=0).cuda() + num_views = imgs_in.shape[1] + imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W) + + target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1] + + normal_prompt_embeddings, clr_prompt_embeddings = data['normal_prompt_embeddings'].cuda(), data['color_prompt_embeddings'].cuda() + prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0) + prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + + # B*Nv images + unet_out = pipeline( + imgs_in, None, prompt_embeds=prompt_embeddings, + generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1, + height=cfg.height, width=cfg.width, + num_inference_steps=40, eta=1.0, + num_levels=num_levels, + ) + + for level in range(num_levels): + out = unet_out[level].images + bsz = out.shape[0] // 2 + + normals_pred = out[:bsz] + images_pred = out[bsz:] + + if num_levels == 2: + results[level+1] = {'normals': [], 'images': []} + else: + results[level] = {'normals': [], 'images': []} + + for i in range(bsz//num_views): + img_in_ = images_cond[-1][i].to(out.device) + for j in range(num_views): + view = VIEWS[j] + idx = i*num_views + j + normal = normals_pred[idx] + color = images_pred[idx] + + ## save color and normal--------------------- + new_normal = save_image(normal) + new_color = save_image(color) + + if num_levels == 2: + results[level+1]['normals'].append(new_normal) + results[level+1]['images'].append(new_color) + else: + results[level]['normals'].append(new_normal) + results[level]['images'].append(new_color) + + torch.cuda.empty_cache() + return results + + +def load_multiview_pipeline(cfg): + pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained( + cfg.pretrained_path, + torch_dtype=torch.float16,) + pipeline.unet.enable_xformers_memory_efficient_attention() + if torch.cuda.is_available(): + pipeline.to(device) + return pipeline + + +class InferAPI: + def __init__(self, + canonical_configs, + multiview_configs, + slrm_configs, + refine_configs): + self.canonical_configs = canonical_configs + self.multiview_configs = multiview_configs + self.slrm_configs = slrm_configs + self.refine_configs = refine_configs + + repo_id = "hyz317/StdGEN" + all_files = list_repo_files(repo_id, revision="main") + for file in all_files: + if os.path.exists(file): + continue + hf_hub_download(repo_id, file, local_dir="./ckpt") + + self.canonical_infer = InferCanonicalAPI(self.canonical_configs) + self.multiview_infer = InferMultiviewAPI(self.multiview_configs) + self.slrm_infer = InferSlrmAPI(self.slrm_configs) + self.refine_infer = InferRefineAPI(self.refine_configs) + + def genStage1(self, img, seed): + return self.canonical_infer.gen(img, seed) + + def genStage2(self, img, seed, num_levels): + return self.multiview_infer.gen(img, seed, num_levels) + + def genStage3(self, img): + return self.slrm_infer.gen(img) + + def genStage4(self, meshes, imgs): + return self.refine_infer.refine(meshes, imgs) + + +############## Refine ############## +def fix_vert_color_glb(mesh_path): + from pygltflib import GLTF2, Material, 
PbrMetallicRoughness + obj1 = GLTF2().load(mesh_path) + obj1.meshes[0].primitives[0].material = 0 + obj1.materials.append(Material( + pbrMetallicRoughness = PbrMetallicRoughness( + baseColorFactor = [1.0, 1.0, 1.0, 1.0], + metallicFactor = 0., + roughnessFactor = 1.0, + ), + emissiveFactor = [0.0, 0.0, 0.0], + doubleSided = True, + )) + obj1.save(mesh_path) + + +def srgb_to_linear(c_srgb): + c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) + return c_linear.clip(0, 1.) + + +def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): + # convert from pytorch3d meshes to trimesh mesh + vertices = meshes.verts_packed().cpu().float().numpy() + triangles = meshes.faces_packed().cpu().long().numpy() + np_color = meshes.textures.verts_features_packed().cpu().float().numpy() + if save_glb_path.endswith(".glb"): + # rotate 180 along +Y + vertices[:, [0, 2]] = -vertices[:, [0, 2]] + + if apply_sRGB_to_LinearRGB: + np_color = srgb_to_linear(np_color) + assert vertices.shape[0] == np_color.shape[0] + assert np_color.shape[1] == 3 + assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}" + np_color = np.clip(np_color, 0, 1) + mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) + mesh.remove_unreferenced_vertices() + # save mesh + mesh.export(save_glb_path) + if save_glb_path.endswith(".glb"): + fix_vert_color_glb(save_glb_path) + print(f"saving to {save_glb_path}") + + +def calc_horizontal_offset(target_img, source_img): + target_mask = target_img.astype(np.float32).sum(axis=-1) > 750 + source_mask = source_img.astype(np.float32).sum(axis=-1) > 750 + best_offset = -114514 + for offset in range(-200, 200): + offset_mask = np.roll(source_mask, offset, axis=1) + overlap = (target_mask & offset_mask).sum() + if overlap > best_offset: + best_offset = overlap + best_offset_value = offset + return best_offset_value + + +def calc_horizontal_offset2(target_mask, source_img): + source_mask = source_img.astype(np.float32).sum(axis=-1) > 750 + best_offset = -114514 + for offset in range(-200, 200): + offset_mask = np.roll(source_mask, offset, axis=1) + overlap = (target_mask & offset_mask).sum() + if overlap > best_offset: + best_offset = overlap + best_offset_value = offset + return best_offset_value + + +def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20): + distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres + if normal_0 is not None and normal_1 is not None: + distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres + labeled_array, num_features = scipy.ndimage.label(distract_area) + results = [] + + random_sampled_points = [] + + for i in range(num_features + 1): + if np.sum(labeled_array == i) > 1000 and np.sum(labeled_array == i) < 100000: + results.append((i, np.sum(labeled_array == i))) + # random sample a point in the area + points = np.argwhere(labeled_array == i) + random_sampled_points.append(points[np.random.randint(0, points.shape[0])]) + + results = sorted(results, key=lambda x: x[1], reverse=True) # [1:] + distract_mask = np.zeros_like(distract_area) + distract_bbox = np.zeros_like(distract_area) + for i, _ in results: + distract_mask |= labeled_array == i + bbox = np.argwhere(labeled_array == i) + min_x, min_y = bbox.min(axis=0) + max_x, max_y = bbox.max(axis=0) + distract_bbox[min_x:max_x, min_y:max_y] = 1 + + points = 
np.array(random_sampled_points)[:, ::-1] + labels = np.ones(len(points), dtype=np.int32) + + masks = generator.generate((color_1 * 255).astype(np.uint8)) + + outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres + + final_mask = np.zeros_like(distract_mask) + for iii, mask in enumerate(masks): + mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5 + intersection = np.logical_and(mask['segmentation'], distract_mask).sum() + total = mask['segmentation'].sum() + iou = intersection / total + outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum() + outside_total = mask['segmentation'].sum() + outside_iou = outside_intersection / outside_total + if iou > ratio and outside_iou < outside_ratio: + final_mask |= mask['segmentation'] + + # calculate coverage + intersection = np.logical_and(final_mask, distract_mask).sum() + total = distract_mask.sum() + coverage = intersection / total + + if coverage < 0.8: + # use original distract mask + final_mask = (distract_mask.copy() * 255).astype(np.uint8) + final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3) + labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask) + for i in range(num_features_dilate + 1): + if np.sum(labeled_array_dilate == i) < 200: + final_mask[labeled_array_dilate == i] = 255 + + final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3) + final_mask = final_mask > 127 + + return distract_mask, distract_bbox, random_sampled_points, final_mask + + +class InferRefineAPI: + def __init__(self, config): + self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda() + self.generator = SamAutomaticMaskGenerator( + model=self.sam, + points_per_side=64, + pred_iou_thresh=0.80, + stability_score_thresh=0.92, + crop_n_layers=1, + crop_n_points_downscale_factor=2, + min_mask_region_area=100, + ) + self.outside_ratio = 0.20 + + def refine(self, meshes, imgs): + fixed_v, fixed_f, fixed_t = None, None, None + flow_vert, flow_vector = None, None + last_colors, last_normals = None, None + last_front_color, last_front_normal = None, None + distract_mask = None + + mv, proj = make_star_cameras_orthographic(8, 1, r=1.2) + mv = mv[[4, 3, 2, 0, 6, 5]] + renderer = NormalsRenderer(mv,proj,(1024,1024)) + + results = [] + + for name_idx, level in zip([2, 0, 1], [2, 1, 0]): + mesh = trimesh.load(meshes[name_idx]) + new_mesh = mesh.split(only_watertight=False) + new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ] + mesh = trimesh.Scene(new_mesh).dump(concatenate=True) + mesh_v, mesh_f = mesh.vertices, mesh.faces + + if last_colors is None: + images = renderer.render( + torch.tensor(mesh_v, device='cuda').float(), + torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(), + torch.tensor(mesh_f, device='cuda'), + ) + mask = (images[..., 3] < 0.9).cpu().numpy() + + colors, normals = [], [] + for i in range(6): + color = np.array(imgs[level]['images'][i]) + normal = np.array(imgs[level]['normals'][i]) + + if last_colors is not None: + offset = calc_horizontal_offset(np.array(last_colors[i]), color) + # print('offset', i, offset) + else: + offset = calc_horizontal_offset2(mask[i], color) + # print('init offset', i, offset) + + if offset != 0: + color = np.roll(color, offset, axis=1) + normal = np.roll(normal, offset, axis=1) + + color = Image.fromarray(color) + normal = Image.fromarray(normal) + colors.append(color) + normals.append(normal) + + if last_front_color is not None and 
level == 0: + original_mask, distract_bbox, _, distract_mask = get_distract_mask(self.generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=self.outside_ratio) + else: + distract_mask = None + distract_bbox = None + + last_front_color = np.array(colors[0]).astype(np.float32) / 255.0 + last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0 + + if last_colors is None: + from copy import deepcopy + last_colors, last_normals = deepcopy(colors), deepcopy(normals) + + # my mesh flow weight by nearest vertexs + if fixed_v is not None and fixed_f is not None and level == 1: + t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f) + + fixed_v_cpu = fixed_v.cpu().numpy() + kdtree_anchor = KDTree(fixed_v_cpu) + kdtree_mesh_v = KDTree(mesh_v) + _, idx_anchor = kdtree_anchor.query(mesh_v, k=1) + _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25) + idx_anchor = idx_anchor.squeeze() + neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3 + # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25] + neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1) + neighbor_dists[neighbor_dists > 0.06] = 114514. + neighbor_weights = torch.exp(-neighbor_dists * 1.) + neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True) + anchors = fixed_v[idx_anchor] # V, 3 + anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3 + dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01 + vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3 + vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3 + weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3 + mesh_v += weighted_vec_anchor.cpu().numpy() + + t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f) + + mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32) + mesh_f = torch.tensor(mesh_f, device='cuda') + + new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox) + + # my mesh flow weight by nearest vertexs + try: + if fixed_v is not None and fixed_f is not None and level != 0: + new_mesh_v = new_mesh.verts_packed().cpu().numpy() + + fixed_v_cpu = fixed_v.cpu().numpy() + kdtree_anchor = KDTree(fixed_v_cpu) + kdtree_mesh_v = KDTree(new_mesh_v) + _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1) + _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25) + idx_anchor = idx_anchor.squeeze() + neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3 + # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25] + neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1) + neighbor_dists[neighbor_dists > 0.06] = 114514. + neighbor_weights = torch.exp(-neighbor_dists * 1.) 
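+ # Descriptive note on the step below: normalize the 25-nearest-neighbour weights into a convex combination, then displace each vertex along its nearest fixed-mesh anchor's normal by the clamped signed anchor distance (plus a 0.01 margin), averaged with those weights, so the current mesh is flowed to sit just outside the already-fixed geometry.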
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True) + anchors = fixed_v[idx_anchor] # V, 3 + anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3 + dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01 + vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3 + vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3 + weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3 + new_mesh_v += weighted_vec_anchor.cpu().numpy() + + # replace new_mesh verts with new_mesh_v + new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures) + + except Exception as e: + pass + + notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed() + + if fixed_v is None: + fixed_v, fixed_f = simp_v, simp_f + complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t + else: + fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0) + fixed_v = torch.cat([fixed_v, simp_v], dim=0) + + complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0) + complete_v = torch.cat([complete_v, notsimp_v], dim=0) + complete_t = torch.cat([complete_t, notsimp_t], dim=0) + + if level == 2: + new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5])) + + save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False) + results.append(meshes[name_idx].replace('.obj', '_refined.obj')) + + # save whole mesh + save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False) + results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj')) + + return results + + +class InferSlrmAPI: + def __init__(self, config): + self.config_path = config['config_path'] + self.config = OmegaConf.load(self.config_path) + self.config_name = os.path.basename(self.config_path).replace('.yaml', '') + self.model_config = self.config.model_config + self.infer_config = self.config.infer_config + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = instantiate_from_config(self.model_config) + state_dict = torch.load(self.infer_config.model_path, map_location='cpu') + self.model.load_state_dict(state_dict, strict=False) + self.model = self.model.to(self.device) + self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho) + self.model = self.model.eval() + + def gen(self, imgs): + imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ] + imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0 + imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024) + mesh_glb_fpaths = self.make3d(imgs) + return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1] + + def make3d(self, images): + input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device) + + images = images.unsqueeze(0).to(device) + images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) + + mesh_fpath = 
tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name + print(mesh_fpath) + mesh_basename = os.path.basename(mesh_fpath).split('.')[0] + mesh_dirname = os.path.dirname(mesh_fpath) + + with torch.no_grad(): + # get triplane + planes = self.model.forward_planes(images, input_cameras.float()) + + # get mesh + mesh_glb_fpaths = [] + for j in range(4): + mesh_glb_fpath = self.make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j]) + mesh_glb_fpaths.append(mesh_glb_fpath) + + return mesh_glb_fpaths + + def make_mesh(self, mesh_fpath, planes, level=None): + mesh_basename = os.path.basename(mesh_fpath).split('.')[0] + mesh_dirname = os.path.dirname(mesh_fpath) + mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb") + + with torch.no_grad(): + # get mesh + mesh_out = self.model.extract_mesh( + planes, + use_texture_map=False, + levels=torch.tensor([level]).to(device), + **self.infer_config, + ) + + vertices, faces, vertex_colors = mesh_out + vertices = vertices[:, [1, 2, 0]] + + if level == 2: + # fill all vertex_colors with 127 + vertex_colors = np.ones_like(vertex_colors) * 127 + + save_obj(vertices, faces, vertex_colors, mesh_fpath) + + return mesh_fpath + + +class InferMultiviewAPI: + def __init__(self, config): + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num_views", type=int, default=6) + parser.add_argument("--num_levels", type=int, default=3) + parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024') + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=576) + self.cfg = parser.parse_args() + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.pipeline = load_multiview_pipeline(self.cfg) + self.results = {} + if torch.cuda.is_available(): + self.pipeline.to(device) + + self.image_transforms = [transforms.Resize(int(max(self.cfg.height, self.cfg.width))), + transforms.CenterCrop((self.cfg.height, self.cfg.width)), + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 2. - 1), + ] + self.image_transforms = transforms.Compose(self.image_transforms) + + prompt_embeds_path = './multiview/fixed_prompt_embeds_6view' + self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') + self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') + self.total_views = self.cfg.num_views + + + def process_im(self, im): + im = self.image_transforms(im) + return im + + + def gen(self, img, seed, num_levels): + set_seed(seed) + data = {} + + cond_im_rgb = self.process_im(img) + cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0) + data["image_cond_rgb"] = cond_im_rgb[None, ...] + data["normal_prompt_embeddings"] = self.normal_text_embeds[None, ...] + data["color_prompt_embeddings"] = self.color_text_embeds[None, ...] 
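+        # shapes as constructed above: image_cond_rgb is (1, num_views, 3, H, W); both prompt embeddings carry an added leading batch dim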
+ + results = run_multiview_infer(data, self.pipeline, self.cfg, num_levels=num_levels) + for k in results: + self.results[k] = results[k] + return results + + +class InferCanonicalAPI: + def __init__(self, config): + self.config = config + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.config_path = config['config_path'] + self.loaded_config = OmegaConf.load(self.config_path) + + self.setup(**self.loaded_config) + + def setup(self, + validation: Dict, + pretrained_model_path: str, + local_crossattn: bool = True, + unet_from_pretrained_kwargs=None, + unet_condition_type=None, + use_noise=True, + noise_d=256, + timestep: int = 40, + width_input: int = 640, + height_input: int = 1024, + ): + self.width_input = width_input + self.height_input = height_input + self.timestep = timestep + self.use_noise = use_noise + self.noise_d = noise_d + self.validation = validation + self.unet_condition_type = unet_condition_type + self.pretrained_model_path = pretrained_model_path + + self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder") + self.feature_extractor = CLIPImageProcessor() + self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + self.unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + self.ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + + self.text_encoder.to(device, dtype=weight_dtype) + self.image_encoder.to(device, dtype=weight_dtype) + self.vae.to(device, dtype=weight_dtype) + self.ref_unet.to(device, dtype=weight_dtype) + self.unet.to(device, dtype=weight_dtype) + + self.vae.requires_grad_(False) + self.ref_unet.requires_grad_(False) + self.unet.requires_grad_(False) + + self.noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr") + self.validation_pipeline = CanonicalizationPipeline( + vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, ref_unet=self.ref_unet,feature_extractor=self.feature_extractor,image_encoder=self.image_encoder, + scheduler=self.noise_scheduler + ) + self.validation_pipeline.set_progress_bar_config(disable=True) + + self.bkg_remover = BkgRemover() + + def canonicalize(self, image, seed): + generator = torch.Generator(device=device).manual_seed(seed) + return inference( + self.validation_pipeline, self.bkg_remover, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder, + self.pretrained_model_path, generator, self.validation, self.width_input, self.height_input, self.unet_condition_type, + use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep + ) + + def gen(self, img_input, seed=0): + if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255: + # convert to RGB + img_input = img_input.convert("RGB") + img_output = self.canonicalize(img_input, seed) + + max_dim = max(img_output.width, img_output.height) + new_image = Image.new("RGBA", (max_dim, max_dim)) + left = (max_dim - img_output.width) // 2 + top = (max_dim - 
img_output.height) // 2 + new_image.paste(img_output, (left, top)) + + return new_image diff --git a/infer_canonicalize.py b/infer_canonicalize.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2a2b383f433aceb574056c03b9372bb1d0fa56 --- /dev/null +++ b/infer_canonicalize.py @@ -0,0 +1,215 @@ +from PIL import Image +import glob + +import io +import argparse +import inspect +import os +import random +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf +import numpy as np + +import torch + +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.utils import check_min_version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection +from torchvision import transforms + +from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel +from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel +from canonicalize.pipeline_canonicalize import CanonicalizationPipeline +from einops import rearrange +from torchvision.utils import save_image +import json +import cv2 + +import onnxruntime as rt +from huggingface_hub.file_download import hf_hub_download +from rm_anime_bg.cli import get_mask, SCALE + +check_min_version("0.24.0") +weight_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class BkgRemover: + def __init__(self, force_cpu: Optional[bool] = True): + session_infer_path = hf_hub_download( + repo_id="skytnt/anime-seg", filename="isnetis.onnx", + ) + providers: list[str] = ["CPUExecutionProvider"] + if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers(): + providers = ["CUDAExecutionProvider"] + + self.session_infer = rt.InferenceSession( + session_infer_path, providers=providers, + ) + + def remove_background( + self, + img: np.ndarray, + alpha_min: float, + alpha_max: float, + ) -> list: + img = np.array(img) + mask = get_mask(self.session_infer, img) + mask[mask < alpha_min] = 0.0 + mask[mask > alpha_max] = 1.0 + img_after = (mask * img).astype(np.uint8) + mask = (mask * SCALE).astype(np.uint8) + img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8) + return Image.fromarray(img_after) + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def process_image(image, totensor, width, height): + assert image.mode == "RGBA" + + # Find non-transparent pixels + non_transparent = np.nonzero(np.array(image)[..., 3]) + min_x, max_x = non_transparent[1].min(), non_transparent[1].max() + min_y, max_y = non_transparent[0].min(), non_transparent[0].max() + image = image.crop((min_x, min_y, max_x, max_y)) + + # paste to center + max_dim = max(image.width, image.height) + max_height = int(max_dim * 1.2) + max_width = int(max_dim / (height/width) * 1.2) + new_image = Image.new("RGBA", (max_width, max_height)) + left = (max_width - image.width) // 2 + top = (max_height - image.height) // 2 + new_image.paste(image, (left, top)) + + image = new_image.resize((width, height), resample=Image.BICUBIC) + image = np.array(image) + image = image.astype(np.float32) / 255. 
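+    # composite the RGBA image onto a white background: keep RGB where alpha is 1, fall back to bg_color where alpha is 0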
+ assert image.shape[-1] == 4 # RGBA + alpha = image[..., 3:4] + bg_color = np.array([1., 1., 1.], dtype=np.float32) + image = image[..., :3] * alpha + bg_color * (1 - alpha) + return totensor(image) + + +@torch.no_grad() +def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, + text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type, + use_noise=True, noise_d=256, crop=False, seed=100, timestep=20): + set_seed(seed) + + totensor = transforms.ToTensor() + + prompts = "high quality, best quality" + prompt_ids = tokenizer( + prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, + return_tensors="pt" + ).input_ids[0] + + # (B*Nv, 3, H, W) + B = 1 + if input_image.mode != "RGBA": + # remove background + input_image = bkg_remover.remove_background(input_image, 0.1, 0.9) + imgs_in = process_image(input_image, totensor, val_width, val_height) + imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W") + + with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype): + imgs_in = imgs_in.to(device=device) + # B*Nv images + out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator, + num_inference_steps=timestep, prompt_ids=prompt_ids, + height=val_height, width=val_width, unet_condition_type=unet_condition_type, + use_noise=use_noise, **validation,) + out = rearrange(out, "B C f H W -> (B f) C H W", f=1) + + img_buf = io.BytesIO() + save_image(out[0], img_buf, format='PNG') + img_buf.seek(0) + img = Image.open(img_buf) + + torch.cuda.empty_cache() + return img + + +@torch.no_grad() +def main( + input_dir: str, + output_dir: str, + pretrained_model_path: str, + validation: Dict, + local_crossattn: bool = True, + unet_from_pretrained_kwargs=None, + unet_condition_type=None, + use_noise=True, + noise_d=256, + seed: int = 42, + timestep: int = 40, + width_input: int = 640, + height_input: int = 1024, +): + *_, config = inspect.getargvalues(inspect.currentframe()) + + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder") + feature_extractor = CLIPImageProcessor() + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + + text_encoder.to(device, dtype=weight_dtype) + image_encoder.to(device, dtype=weight_dtype) + vae.to(device, dtype=weight_dtype) + ref_unet.to(device, dtype=weight_dtype) + unet.to(device, dtype=weight_dtype) + + vae.requires_grad_(False) + unet.requires_grad_(False) + ref_unet.requires_grad_(False) + + # set pipeline + noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr") + validation_pipeline = CanonicalizationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder, + scheduler=noise_scheduler + ) + 
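+    # pipeline setup here mirrors InferCanonicalAPI.setup: frozen CLIP text/image encoders, VAE, MV UNet + reference UNet, zero-SNR DDIM scheduler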
validation_pipeline.set_progress_bar_config(disable=True) + + bkg_remover = BkgRemover() + + def canonicalize(image, width, height, seed, timestep): + generator = torch.Generator(device=device).manual_seed(seed) + return inference( + validation_pipeline, bkg_remover, image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, + pretrained_model_path, generator, validation, width, height, unet_condition_type, + use_noise=use_noise, noise_d=noise_d, crop=True, seed=seed, timestep=timestep + ) + + img_paths = sorted(glob.glob(os.path.join(input_dir, "*.png"))) + os.makedirs(output_dir, exist_ok=True) + + for path in tqdm(img_paths): + img_input = Image.open(path) + if np.array(img_input)[..., 3].min() == 255: + # convert to RGB + img_input = img_input.convert("RGB") + img_output = canonicalize(img_input, width_input, height_input, seed, timestep) + img_output.save(os.path.join(output_dir, f"{os.path.basename(path).split('.')[0]}.png")) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/canonicalization-infer.yaml") + parser.add_argument("--input_dir", type=str, default="./input_cases") + parser.add_argument("--output_dir", type=str, default="./result/apose") + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + main(**OmegaConf.load(args.config), seed=args.seed, input_dir=args.input_dir, output_dir=args.output_dir) \ No newline at end of file diff --git a/infer_multiview.py b/infer_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1769ebddb320905b9aee170087ed26b953d8c1 --- /dev/null +++ b/infer_multiview.py @@ -0,0 +1,274 @@ +import argparse +import os +import cv2 +import glob +import numpy as np +import matplotlib.pyplot as plt +from typing import Dict, Optional, List +from omegaconf import OmegaConf, DictConfig +from PIL import Image +from pathlib import Path +from dataclasses import dataclass +from typing import Dict +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torchvision.utils import make_grid, save_image +from accelerate.utils import set_seed +from tqdm.auto import tqdm +from einops import rearrange, repeat +from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline + +weight_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def tensor_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + +def nonzero_normalize_depth(depth, mask=None): + if mask.max() > 0: # not all transparent + nonzero_depth_min = depth[mask > 0].min() + else: + nonzero_depth_min = 0 + depth = (depth - nonzero_depth_min) / depth.max() + return np.clip(depth, 0, 1) + + +class SingleImageData(Dataset): + def __init__(self, + input_dir, + prompt_embeds_path='./multiview/fixed_prompt_embeds_6view', + image_transforms=[], + total_views=6, + ext="png", + return_paths=True, + ) -> None: + """Create a dataset from a folder of images. 
+ If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.input_dir = Path(input_dir) + self.return_paths = return_paths + self.total_views = total_views + + self.paths = glob.glob(str(self.input_dir / f'*.{ext}')) + + print('============= length of dataset %d =============' % len(self.paths)) + self.tform = image_transforms + self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt') + self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') + + + def __len__(self): + return len(self.paths) + + + def load_rgb(self, path, color): + img = plt.imread(path) + img = Image.fromarray(np.uint8(img * 255.)) + new_img = Image.new("RGB", (1024, 1024)) + # white background + width, height = img.size + new_width = int(width / height * 1024) + img = img.resize((new_width, 1024)) + new_img.paste((255, 255, 255), (0, 0, 1024, 1024)) + offset = (1024 - new_width) // 2 + new_img.paste(img, (offset, 0)) + return new_img + + def __getitem__(self, index): + data = {} + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + color = 1.0 + cond_im_rgb = self.process_im(self.load_rgb(filename, color)) + cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0) + + data["image_cond_rgb"] = cond_im_rgb + data["normal_prompt_embeddings"] = self.normal_text_embeds + data["color_prompt_embeddings"] = self.color_text_embeds + data["filename"] = filename.split('/')[-1] + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + + def tensor_to_image(self, tensor): + return Image.fromarray(np.uint8(tensor.numpy() * 255.)) + + +@dataclass +class TestConfig: + pretrained_model_name_or_path: str + pretrained_unet_path:Optional[str] + revision: Optional[str] + validation_dataset: Dict + save_dir: str + seed: Optional[int] + validation_batch_size: int + dataloader_num_workers: int + save_mode: str + local_rank: int + + pipe_kwargs: Dict + pipe_validation_kwargs: Dict + unet_from_pretrained_kwargs: Dict + validation_grid_nrow: int + camera_embedding_lr_mult: float + + num_views: int + camera_embedding_type: str + + pred_type: str + regress_elevation: bool + enable_xformers_memory_efficient_attention: bool + + cond_on_normals: bool + cond_on_colors: bool + + regress_elevation: bool + regress_focal_length: bool + + + +def convert_to_numpy(tensor): + return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + +def save_image(tensor, fp): + ndarr = convert_to_numpy(tensor) + save_image_numpy(ndarr, fp) + return ndarr + +def save_image_numpy(ndarr, fp): + im = Image.fromarray(ndarr) + # pad to square + if im.size[0] != im.size[1]: + size = max(im.size) + new_im = Image.new("RGB", (size, size)) + # set to white + new_im.paste((255, 255, 255), (0, 0, size, size)) + new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2)) + im = new_im + # resize to 1024x1024 + im = im.resize((1024, 1024), Image.LANCZOS) + im.save(fp) + +def run_multiview_infer(dataloader, pipeline, cfg: TestConfig, save_dir, num_levels=3): + if cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed) + + images_cond = [] + for _, batch in tqdm(enumerate(dataloader)): + torch.cuda.empty_cache() + images_cond.append(batch['image_cond_rgb'][:, 0].cuda()) + imgs_in = torch.cat([batch['image_cond_rgb']]*2, dim=0).cuda() + num_views = imgs_in.shape[1] + imgs_in = rearrange(imgs_in, "B 
Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W) + + target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1] + + normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'].cuda(), batch['color_prompt_embeddings'].cuda() + prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0) + prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C") + + # B*Nv images + unet_out = pipeline( + imgs_in, None, prompt_embeds=prompt_embeddings, + generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1, + height=cfg.height, width=cfg.width, + num_inference_steps=40, eta=1.0, + num_levels=num_levels, + ) + + for level in range(num_levels): + out = unet_out[level].images + bsz = out.shape[0] // 2 + + normals_pred = out[:bsz] + images_pred = out[bsz:] + + cur_dir = save_dir + os.makedirs(cur_dir, exist_ok=True) + + for i in range(bsz//num_views): + scene = batch['filename'][i].split('.')[0] + scene_dir = os.path.join(cur_dir, scene, f'level{level}') + os.makedirs(scene_dir, exist_ok=True) + + img_in_ = images_cond[-1][i].to(out.device) + for j in range(num_views): + view = VIEWS[j] + idx = i*num_views + j + normal = normals_pred[idx] + color = images_pred[idx] + + ## save color and normal--------------------- + normal_filename = f"normal_{j}.png" + rgb_filename = f"color_{j}.png" + save_image(normal, os.path.join(scene_dir, normal_filename)) + save_image(color, os.path.join(scene_dir, rgb_filename)) + + torch.cuda.empty_cache() + +def load_multiview_pipeline(cfg): + pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained( + cfg.pretrained_path, + torch_dtype=torch.float16,) + pipeline.unet.enable_xformers_memory_efficient_attention() + if torch.cuda.is_available(): + pipeline.to(device) + return pipeline + +def main( + cfg: TestConfig +): + set_seed(cfg.seed) + pipeline = load_multiview_pipeline(cfg) + if torch.cuda.is_available(): + pipeline.to(device) + + image_transforms = [transforms.Resize(int(max(cfg.height, cfg.width))), + transforms.CenterCrop((cfg.height, cfg.width)), + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 2. 
- 1), + ] + image_transforms = transforms.Compose(image_transforms) + dataset = SingleImageData(image_transforms=image_transforms, input_dir=cfg.input_dir, total_views=cfg.num_views) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, shuffle=False, num_workers=1 + ) + os.makedirs(cfg.output_dir, exist_ok=True) + + with torch.no_grad(): + run_multiview_infer(dataloader, pipeline, cfg, cfg.output_dir, num_levels=cfg.num_levels) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num_views", type=int, default=6) + parser.add_argument("--num_levels", type=int, default=3) + parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024') + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=576) + parser.add_argument("--input_dir", type=str, default='./result/apose') + parser.add_argument("--output_dir", type=str, default='./result/multiview') + cfg = parser.parse_args() + + if cfg.num_views == 6: + VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + else: + raise NotImplementedError(f"Number of views {cfg.num_views} not supported") + main(cfg) diff --git a/infer_refine.py b/infer_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..3c40f6ad4e0483fd70890a44550d8f168852f5a8 --- /dev/null +++ b/infer_refine.py @@ -0,0 +1,353 @@ +import cv2 +import numpy as np +import os +import trimesh +import argparse +import torch +import scipy +from PIL import Image + +from refine.mesh_refine import geo_refine +from refine.func import make_star_cameras_orthographic +from refine.render import NormalsRenderer, calc_vertex_normals + +from pytorch3d.structures import Meshes +from sklearn.neighbors import KDTree + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + +sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda() +generator = SamAutomaticMaskGenerator( + model=sam, + points_per_side=64, + pred_iou_thresh=0.80, + stability_score_thresh=0.92, + crop_n_layers=1, + crop_n_points_downscale_factor=2, + min_mask_region_area=100, +) + + +def fix_vert_color_glb(mesh_path): + from pygltflib import GLTF2, Material, PbrMetallicRoughness + obj1 = GLTF2().load(mesh_path) + obj1.meshes[0].primitives[0].material = 0 + obj1.materials.append(Material( + pbrMetallicRoughness = PbrMetallicRoughness( + baseColorFactor = [1.0, 1.0, 1.0, 1.0], + metallicFactor = 0., + roughnessFactor = 1.0, + ), + emissiveFactor = [0.0, 0.0, 0.0], + doubleSided = True, + )) + obj1.save(mesh_path) + + +def srgb_to_linear(c_srgb): + c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) + return c_linear.clip(0, 1.) 
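+# srgb_to_linear applies the standard piecewise sRGB decode: c / 12.92 for c <= 0.04045, otherwise
+# ((c + 0.055) / 1.055) ** 2.4, clipped to [0, 1]; save_py3dmesh_with_trimesh_fast below runs it on the
+# packed vertex colors when apply_sRGB_to_LinearRGB is set.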
+ + +def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): + # convert from pytorch3d meshes to trimesh mesh + vertices = meshes.verts_packed().cpu().float().numpy() + triangles = meshes.faces_packed().cpu().long().numpy() + np_color = meshes.textures.verts_features_packed().cpu().float().numpy() + if save_glb_path.endswith(".glb"): + # rotate 180 along +Y + vertices[:, [0, 2]] = -vertices[:, [0, 2]] + + if apply_sRGB_to_LinearRGB: + np_color = srgb_to_linear(np_color) + assert vertices.shape[0] == np_color.shape[0] + assert np_color.shape[1] == 3 + assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}" + np_color = np.clip(np_color, 0, 1) + mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) + mesh.remove_unreferenced_vertices() + # save mesh + mesh.export(save_glb_path) + if save_glb_path.endswith(".glb"): + fix_vert_color_glb(save_glb_path) + print(f"saving to {save_glb_path}") + + +def calc_horizontal_offset(target_img, source_img): + target_mask = target_img.astype(np.float32).sum(axis=-1) > 750 + source_mask = source_img.astype(np.float32).sum(axis=-1) > 750 + best_offset = -114514 + for offset in range(-200, 200): + offset_mask = np.roll(source_mask, offset, axis=1) + overlap = (target_mask & offset_mask).sum() + if overlap > best_offset: + best_offset = overlap + best_offset_value = offset + return best_offset_value + + +def calc_horizontal_offset2(target_mask, source_img): + source_mask = source_img.astype(np.float32).sum(axis=-1) > 750 + best_offset = -114514 + for offset in range(-200, 200): + offset_mask = np.roll(source_mask, offset, axis=1) + overlap = (target_mask & offset_mask).sum() + if overlap > best_offset: + best_offset = overlap + best_offset_value = offset + return best_offset_value + + +def get_distract_mask(color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20): + distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres + if normal_0 is not None and normal_1 is not None: + distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres + labeled_array, num_features = scipy.ndimage.label(distract_area) + results = [] + + random_sampled_points = [] + + for i in range(num_features + 1): + if np.sum(labeled_array == i) > 1000 and np.sum(labeled_array == i) < 100000: + results.append((i, np.sum(labeled_array == i))) + # random sample a point in the area + points = np.argwhere(labeled_array == i) + random_sampled_points.append(points[np.random.randint(0, points.shape[0])]) + + results = sorted(results, key=lambda x: x[1], reverse=True) # [1:] + distract_mask = np.zeros_like(distract_area) + distract_bbox = np.zeros_like(distract_area) + for i, _ in results: + distract_mask |= labeled_array == i + bbox = np.argwhere(labeled_array == i) + min_x, min_y = bbox.min(axis=0) + max_x, max_y = bbox.max(axis=0) + distract_bbox[min_x:max_x, min_y:max_y] = 1 + + points = np.array(random_sampled_points)[:, ::-1] + labels = np.ones(len(points), dtype=np.int32) + + masks = generator.generate((color_1 * 255).astype(np.uint8)) + + outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres + + final_mask = np.zeros_like(distract_mask) + for iii, mask in enumerate(masks): + mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5 + intersection = np.logical_and(mask['segmentation'], distract_mask).sum() + total = mask['segmentation'].sum() + iou = 
intersection / total + outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum() + outside_total = mask['segmentation'].sum() + outside_iou = outside_intersection / outside_total + if iou > ratio and outside_iou < outside_ratio: + final_mask |= mask['segmentation'] + + # calculate coverage + intersection = np.logical_and(final_mask, distract_mask).sum() + total = distract_mask.sum() + coverage = intersection / total + + if coverage < 0.8: + # use original distract mask + final_mask = (distract_mask.copy() * 255).astype(np.uint8) + final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3) + labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask) + for i in range(num_features_dilate + 1): + if np.sum(labeled_array_dilate == i) < 200: + final_mask[labeled_array_dilate == i] = 255 + + final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3) + final_mask = final_mask > 127 + + return distract_mask, distract_bbox, random_sampled_points, final_mask + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input_mv_dir', type=str, default='result/multiview') + parser.add_argument('--input_obj_dir', type=str, default='result/slrm') + parser.add_argument('--output_dir', type=str, default='result/refined') + parser.add_argument('--outside_ratio', type=float, default=0.20) + parser.add_argument('--no_decompose', action='store_true') + args = parser.parse_args() + + for test_idx in os.listdir(args.input_mv_dir): + mv_root_dir = os.path.join(args.input_mv_dir, test_idx) + obj_dir = os.path.join(args.input_obj_dir, test_idx) + + fixed_v, fixed_f = None, None + flow_vert, flow_vector = None, None + last_colors, last_normals = None, None + last_front_color, last_front_normal = None, None + distract_mask = None + + mv, proj = make_star_cameras_orthographic(8, 1, r=1.2) + mv = mv[[4, 3, 2, 0, 6, 5]] + renderer = NormalsRenderer(mv,proj,(1024,1024)) + + if not args.no_decompose: + for name_idx, level in zip([3, 1, 2], [2, 1, 0]): + mesh = trimesh.load(obj_dir + f'_{name_idx}.obj') + new_mesh = mesh.split(only_watertight=False) + new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ] + mesh = trimesh.Scene(new_mesh).dump(concatenate=True) + mesh_v, mesh_f = mesh.vertices, mesh.faces + + if last_colors is None: + images = renderer.render( + torch.tensor(mesh_v, device='cuda').float(), + torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(), + torch.tensor(mesh_f, device='cuda'), + ) + mask = (images[..., 3] < 0.9).cpu().numpy() + + colors, normals = [], [] + for i in range(6): + color_path = os.path.join(mv_root_dir, f'level{level}', f'color_{i}.png') + normal_path = os.path.join(mv_root_dir, f'level{level}', f'normal_{i}.png') + color = cv2.imread(color_path) + normal = cv2.imread(normal_path) + color = color[..., ::-1] + normal = normal[..., ::-1] + + if last_colors is not None: + offset = calc_horizontal_offset(np.array(last_colors[i]), color) + # print('offset', i, offset) + else: + offset = calc_horizontal_offset2(mask[i], color) + # print('init offset', i, offset) + + if offset != 0: + color = np.roll(color, offset, axis=1) + normal = np.roll(normal, offset, axis=1) + + color = Image.fromarray(color) + normal = Image.fromarray(normal) + colors.append(color) + normals.append(normal) + + if last_front_color is not None and level == 0: + original_mask, distract_bbox, _, distract_mask = get_distract_mask(last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, 
outside_ratio=args.outside_ratio) + cv2.imwrite(f'{args.output_dir}/{test_idx}/distract_mask.png', distract_mask.astype(np.uint8) * 255) + else: + distract_mask = None + distract_bbox = None + + last_front_color = np.array(colors[0]).astype(np.float32) / 255.0 + last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0 + + if last_colors is None: + from copy import deepcopy + last_colors, last_normals = deepcopy(colors), deepcopy(normals) + + # my mesh flow weight by nearest vertexs + if fixed_v is not None and fixed_f is not None and level == 1: + t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f) + + fixed_v_cpu = fixed_v.cpu().numpy() + kdtree_anchor = KDTree(fixed_v_cpu) + kdtree_mesh_v = KDTree(mesh_v) + _, idx_anchor = kdtree_anchor.query(mesh_v, k=1) + _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25) + idx_anchor = idx_anchor.squeeze() + neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3 + # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25] + neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1) + neighbor_dists[neighbor_dists > 0.06] = 114514. + neighbor_weights = torch.exp(-neighbor_dists * 1.) + neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True) + anchors = fixed_v[idx_anchor] # V, 3 + anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3 + dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01 + vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3 + vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3 + weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3 + mesh_v += weighted_vec_anchor.cpu().numpy() + + t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f) + + mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32) + mesh_f = torch.tensor(mesh_f, device='cuda') + + new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox) + + # my mesh flow weight by nearest vertexs + try: + if fixed_v is not None and fixed_f is not None and level != 0: + new_mesh_v = new_mesh.verts_packed().cpu().numpy() + + fixed_v_cpu = fixed_v.cpu().numpy() + kdtree_anchor = KDTree(fixed_v_cpu) + kdtree_mesh_v = KDTree(new_mesh_v) + _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1) + _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25) + idx_anchor = idx_anchor.squeeze() + neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3 + # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25] + neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1) + neighbor_dists[neighbor_dists > 0.06] = 114514. + neighbor_weights = torch.exp(-neighbor_dists * 1.) 
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True) + anchors = fixed_v[idx_anchor] # V, 3 + anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3 + dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01 + vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3 + vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3 + weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3 + new_mesh_v += weighted_vec_anchor.cpu().numpy() + + # replace new_mesh verts with new_mesh_v + new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures) + + except Exception as e: + pass + + os.makedirs(f'{args.output_dir}/{test_idx}', exist_ok=True) + save_py3dmesh_with_trimesh_fast(new_mesh, f'{args.output_dir}/{test_idx}/out_{level}.glb', apply_sRGB_to_LinearRGB=False) + + if fixed_v is None: + fixed_v, fixed_f = simp_v, simp_f + else: + fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0) + fixed_v = torch.cat([fixed_v, simp_v], dim=0) + + + else: + mesh = trimesh.load(obj_dir + f'_0.obj') + mesh_v, mesh_f = mesh.vertices, mesh.faces + + images = renderer.render( + torch.tensor(mesh_v, device='cuda').float(), + torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(), + torch.tensor(mesh_f, device='cuda'), + ) + mask = (images[..., 3] < 0.9).cpu().numpy() + + colors, normals = [], [] + for i in range(6): + color_path = os.path.join(mv_root_dir, f'level0', f'color_{i}.png') + normal_path = os.path.join(mv_root_dir, f'level0', f'normal_{i}.png') + color = cv2.imread(color_path) + normal = cv2.imread(normal_path) + color = color[..., ::-1] + normal = normal[..., ::-1] + + offset = calc_horizontal_offset2(mask[i], color) + + if offset != 0: + color = np.roll(color, offset, axis=1) + normal = np.roll(normal, offset, axis=1) + + color = Image.fromarray(color) + normal = Image.fromarray(normal) + colors.append(color) + normals.append(normal) + + mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32) + mesh_f = torch.tensor(mesh_f, device='cuda') + + new_mesh, _, _ = geo_refine(mesh_v, mesh_f, colors, normals, no_decompose=True, expansion_weight=0.) + + os.makedirs(f'{args.output_dir}/{test_idx}', exist_ok=True) + save_py3dmesh_with_trimesh_fast(new_mesh, f'{args.output_dir}/{test_idx}/out_nodecomp.glb', apply_sRGB_to_LinearRGB=False) diff --git a/infer_slrm.py b/infer_slrm.py new file mode 100644 index 0000000000000000000000000000000000000000..6447ca305f9bb3ef8be9de9e728527cc10eb7efe --- /dev/null +++ b/infer_slrm.py @@ -0,0 +1,199 @@ +import os +import imageio +import numpy as np +import torch +import cv2 +import glob +import matplotlib.pyplot as plt +from PIL import Image +from torchvision.transforms import v2 +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from tqdm import tqdm + +from slrm.utils.train_util import instantiate_from_config +from slrm.utils.camera_util import ( + FOV_to_intrinsics, + get_circular_camera_poses, +) +from slrm.utils.mesh_util import save_obj, save_glb +from slrm.utils.infer_util import images_to_video + +from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): + """ + Get the rendering camera parameters. 
+ """ + c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation) + if is_flexicubes: + cameras = torch.linalg.inv(c2ws) + cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1) + else: + extrinsics = c2ws.flatten(-2) + intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1) + return cameras + + +def images_to_video(images, output_dir, fps=30): + # images: (N, C, H, W) + os.makedirs(os.path.dirname(output_dir), exist_ok=True) + frames = [] + for i in range(images.shape[0]): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_dir, np.stack(frames), fps=fps, codec='h264') + + +############################################################################### +# Configuration. +############################################################################### + +seed_everything(0) + +config_path = 'configs/mesh-slrm-infer.yaml' +config = OmegaConf.load(config_path) +config_name = os.path.basename(config_path).replace('.yaml', '') +model_config = config.model_config +infer_config = config.infer_config + +IS_FLEXICUBES = True if config_name.startswith('mesh') else False + +device = torch.device('cuda') + +# load reconstruction model +print('Loading reconstruction model ...') +model = instantiate_from_config(model_config) +state_dict = torch.load(infer_config.model_path, map_location='cpu') +model.load_state_dict(state_dict, strict=False) + +model = model.to(device) +if IS_FLEXICUBES: + model.init_flexicubes_geometry(device, fovy=30.0, is_ortho=model.is_ortho) +model = model.eval() + +print('Loading Finished!') + +def make_mesh(mesh_fpath, planes, level=None): + + mesh_basename = os.path.basename(mesh_fpath).split('.')[0] + mesh_dirname = os.path.dirname(mesh_fpath) + mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb") + + with torch.no_grad(): + # get mesh + mesh_out = model.extract_mesh( + planes, + use_texture_map=False, + levels=torch.tensor([level]).to(device), + **infer_config, + ) + + vertices, faces, vertex_colors = mesh_out + vertices = vertices[:, [1, 2, 0]] + + save_glb(vertices, faces, vertex_colors, mesh_glb_fpath) + save_obj(vertices, faces, vertex_colors, mesh_fpath) + + return mesh_fpath, mesh_glb_fpath + + +def make3d(images, name, output_dir): + input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device) + + render_cameras = get_render_cameras( + batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device) + + images = images.unsqueeze(0).to(device) + images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) + + mesh_fpath = os.path.join(output_dir, f"{name}.obj") + + mesh_basename = os.path.basename(mesh_fpath).split('.')[0] + mesh_dirname = os.path.dirname(mesh_fpath) + video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4") + + with torch.no_grad(): + # get triplane + planes = model.forward_planes(images, input_cameras.float()) + + # get video + chunk_size = 20 if IS_FLEXICUBES else 1 + render_size = 512 + + frames = [ [] for _ in range(4) ] + for i in tqdm(range(0, render_cameras.shape[1], 
chunk_size)): + if IS_FLEXICUBES: + frame = model.forward_geometry_separate( + planes, + render_cameras[:, i:i+chunk_size], + render_size=render_size, + levels=torch.tensor([0]).to(device), + )['imgs'] + for j in range(4): + frames[j].append(frame[j]) + else: + frame = model.synthesizer( + planes, + cameras=render_cameras[:, i:i+chunk_size], + render_size=render_size, + )['images_rgb'] + frames.append(frame) + + for j in range(4): + frames[j] = torch.cat(frames[j], dim=1) + video_fpath_j = video_fpath.replace('.mp4', f'_{j}.mp4') + images_to_video( + frames[j][0], + video_fpath_j, + fps=30, + ) + + _, mesh_glb_fpath = make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j]) + + return video_fpath, mesh_fpath, mesh_glb_fpath + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', type=str, default="result/multiview") + parser.add_argument('--output_dir', type=str, default="result/slrm") + args = parser.parse_args() + + paths = glob.glob(args.input_dir + '/*') + os.makedirs(args.output_dir, exist_ok=True) + + def load_rgb(path): + img = plt.imread(path) + img = Image.fromarray(np.uint8(img * 255.)) + return img + + for path in tqdm(paths): + name = path.split('/')[-1] + index_targets = [ + 'level0/color_0.png', + 'level0/color_1.png', + 'level0/color_2.png', + 'level0/color_3.png', + 'level0/color_4.png', + 'level0/color_5.png', + ] + imgs = [] + for index_target in index_targets: + img = load_rgb(os.path.join(path, index_target)) + imgs.append(img) + + imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0 + imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024) + + video_fpath, mesh_fpath, mesh_glb_fpath = make3d(imgs, name, args.output_dir) + diff --git a/input_cases/1.png b/input_cases/1.png new file mode 100644 index 0000000000000000000000000000000000000000..9abaaecbe13f5c7647a0b409b448d881be6ab910 Binary files /dev/null and b/input_cases/1.png differ diff --git a/input_cases/2.png b/input_cases/2.png new file mode 100644 index 0000000000000000000000000000000000000000..511cac94662b2a1e4cf43d8e2cbad84d039337ef Binary files /dev/null and b/input_cases/2.png differ diff --git a/input_cases/3.png b/input_cases/3.png new file mode 100644 index 0000000000000000000000000000000000000000..26fdb01dbd8e5de37e08f51adf68e307d4fe936e Binary files /dev/null and b/input_cases/3.png differ diff --git a/input_cases/4.png b/input_cases/4.png new file mode 100644 index 0000000000000000000000000000000000000000..826855e7b8371a8dd07a4d010597ff320bd0cfbd Binary files /dev/null and b/input_cases/4.png differ diff --git a/input_cases/ayaka.png b/input_cases/ayaka.png new file mode 100755 index 0000000000000000000000000000000000000000..36290b287174c2ba775c0191279f16511012e965 Binary files /dev/null and b/input_cases/ayaka.png differ diff --git a/input_cases/firefly2.png b/input_cases/firefly2.png new file mode 100644 index 0000000000000000000000000000000000000000..44628109b1602c5cd20d4ca29d412ff5648c4f3b Binary files /dev/null and b/input_cases/firefly2.png differ diff --git a/input_cases_apose/1.png b/input_cases_apose/1.png new file mode 100644 index 0000000000000000000000000000000000000000..dc9b2d261493010e49673b97b36519dd5b5c853b Binary files /dev/null and b/input_cases_apose/1.png differ diff --git a/input_cases_apose/2.png b/input_cases_apose/2.png new file mode 100644 index 
0000000000000000000000000000000000000000..e75d10454d0420109cfd4a5ee2ea4dad3d1db60e Binary files /dev/null and b/input_cases_apose/2.png differ diff --git a/input_cases_apose/3.png b/input_cases_apose/3.png new file mode 100644 index 0000000000000000000000000000000000000000..cc57d106e83e82e90bce4720d11cdd1bf6b705e4 Binary files /dev/null and b/input_cases_apose/3.png differ diff --git a/input_cases_apose/4.png b/input_cases_apose/4.png new file mode 100644 index 0000000000000000000000000000000000000000..a8eb4d85a7a337b4705a826e15843f240ea284e8 Binary files /dev/null and b/input_cases_apose/4.png differ diff --git a/input_cases_apose/ayaka.png b/input_cases_apose/ayaka.png new file mode 100644 index 0000000000000000000000000000000000000000..eeb837d114c1174976fdd4dee1797c335c7848fa Binary files /dev/null and b/input_cases_apose/ayaka.png differ diff --git a/input_cases_apose/belle.png b/input_cases_apose/belle.png new file mode 100644 index 0000000000000000000000000000000000000000..f6b43fa9a3d512c06f0f5c0c0ebd3cecdbbd87dc Binary files /dev/null and b/input_cases_apose/belle.png differ diff --git a/input_cases_apose/firefly.png b/input_cases_apose/firefly.png new file mode 100644 index 0000000000000000000000000000000000000000..93a858cc0a04dfcc7fbf01d87af4a53bc69127d6 Binary files /dev/null and b/input_cases_apose/firefly.png differ diff --git a/multiview/__init__.py b/multiview/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/multiview/fixed_prompt_embeds_6view/clr_embeds.pt b/multiview/fixed_prompt_embeds_6view/clr_embeds.pt new file mode 100755 index 0000000000000000000000000000000000000000..de105d6a0a017da97af4644608a87785ec54d9cb --- /dev/null +++ b/multiview/fixed_prompt_embeds_6view/clr_embeds.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9e51666588d0f075e031262744d371e12076160231aab19a531dbf7ab976e4d +size 946932 diff --git a/multiview/fixed_prompt_embeds_6view/normal_embeds.pt b/multiview/fixed_prompt_embeds_6view/normal_embeds.pt new file mode 100755 index 0000000000000000000000000000000000000000..7fb88dcf24443b235588cf426eba3951316e825f --- /dev/null +++ b/multiview/fixed_prompt_embeds_6view/normal_embeds.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53dfcd17f62fbfd8aeba60b1b05fa7559d72179738fd048e2ac1d53e5be5ed9d +size 946941 diff --git a/multiview/models/transformer_mv2d_image.py b/multiview/models/transformer_mv2d_image.py new file mode 100755 index 0000000000000000000000000000000000000000..15699b51f3ada18dd39e51fb9a4d70a1a62496dc --- /dev/null +++ b/multiview/models/transformer_mv2d_image.py @@ -0,0 +1,995 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
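+# Multi-view variant of the diffusers Transformer2DModel: the blocks below add cross-view attention
+# options (num_views, cd_attention_last/mid, sparse_mv_attention, mvcd_attention) on top of the stock layers.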
+from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange, repeat +import pdb +import random + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def my_repeat(tensor, num_repeats): + """ + Repeat a tensor along a given dimension + """ + if len(tensor.shape) == 3: + return repeat(tensor, "b d c -> (b v) d c", v=num_repeats) + elif len(tensor.shape) == 4: + return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats) + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + cd_attention_last: bool=False, + cd_attention_mid: bool=False, + multiview_attention: bool=True, + sparse_mv_attention: bool = False, + mvcd_attention: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. 
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
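+        num_views (`int`, *optional*, defaults to 1):
+            Number of views folded into the batch dimension for multi-view attention.
+        multiview_attention (`bool`, *optional*, defaults to `True`):
+            Whether self-attention keys/values are shared across all views (dense multi-view attention).
+        sparse_mv_attention (`bool`, *optional*, defaults to `False`):
+            With the xformers processor, each view attends only to its own tokens plus the front view's tokens
+            instead of the tokens of all views.
+        mvcd_attention (`bool`, *optional*, defaults to `False`):
+            Additionally exchange keys/values between the two domains stacked along the batch axis
+            (cross-domain attention, e.g. color and normal branches).
+        cd_attention_mid (`bool`, *optional*, defaults to `False`):
+            Insert a zero-initialized joint (cross-domain) attention layer directly after self-attention.
+        cd_attention_last (`bool`, *optional*, defaults to `False`):
+            Insert a zero-initialized joint (cross-domain) attention layer after the feed-forward layer.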
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + mvcd_attention: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.sparse_mv_attention = sparse_mv_attention + self.mvcd_attention = mvcd_attention + + self.attn1 = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + self.cd_attention_last = cd_attention_last + + if self.cd_attention_last: + # Joint task -Attn + self.attn_joint_last = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data) + self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + + self.cd_attention_mid = cd_attention_mid + + if self.cd_attention_mid: + print("cross-domain attn in the middle") + # Joint task -Attn + self.attn_joint_mid = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data) + self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views, + multiview_attention=self.multiview_attention, + sparse_mv_attention=self.sparse_mv_attention, + mvcd_attention=self.mvcd_attention, + **cross_attention_kwargs, + ) + + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # joint attention twice + if self.cd_attention_mid: + norm_hidden_states = ( + self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states) + ) + hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states + + # 2. 
Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + if self.cd_attention_last: + norm_hidden_states = ( + self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states) + ) + hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
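+
+    Multi-view self-attention: for `num_views` views folded into the batch axis, the keys and values of all
+    views are concatenated along the token axis, so the queries of each view attend to the tokens of every
+    view. Illustrative shapes (placeholders, with 6 views and N tokens per view):
+
+        query:        (b * 6, N, c)
+        key / value:  (b * 6, N, c) -> rearrange to (b, 6 * N, c) -> repeat_interleave(6, dim=0) -> (b * 6, 6 * N, c)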
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # multi-view self-attention + if multiview_attention: + if num_views <= 6: + # after use xformer; possible to train with 6 views + key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + else: # apply sparse attention + raise NotImplementedError("sparse attention not implemented yet.") + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
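+
+    xformers (memory-efficient) variant of `MVAttnProcessor`. Besides dense multi-view attention it also
+    implements `sparse_mv_attention`, where each view attends to its own tokens plus the front (first) view's
+    tokens, and `mvcd_attention`, where the keys/values of the two domains stacked along the batch axis are
+    swapped and appended along the token axis so the two domains can attend to each other.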
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1., + multiview_attention=True, + sparse_mv_attention=False, + mvcd_attention=False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # multi-view self-attention + if multiview_attention: + if not sparse_mv_attention: + key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) + value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) + else: + key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c] + value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) + key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c + value = torch.cat([value_front, value_raw], dim=1) + + if mvcd_attention: + # memory efficient, cross domain attention + key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) + key_cross = torch.concat([key_1, key_0], dim=0) + value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c + key = torch.cat([key, key_cross], dim=1) + value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c + else: + # print("don't use multiview attention.") + key = key_raw + value = value_raw + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = 
hidden_states / attn.rescale_output_factor + + return hidden_states + + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
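+
+    Joint (cross-domain) attention used by the `cd_attention_mid` / `cd_attention_last` layers: the batch is
+    assumed to hold `num_tasks == 2` domains stacked along dim 0; their keys and values are concatenated along
+    the token axis and shared by both halves, so each domain attends to the tokens of both domains.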
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + diff --git a/multiview/models/transformer_mv2d_rowwise.py b/multiview/models/transformer_mv2d_rowwise.py new file mode 100755 index 0000000000000000000000000000000000000000..87efbe481be99e2a2a9936a983695b32c7051543 --- /dev/null +++ b/multiview/models/transformer_mv2d_rowwise.py @@ -0,0 +1,972 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
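+
+"""Row-wise multi-view transformer blocks.
+
+This module mirrors the dense multi-view transformer blocks above, but multi-view self-attention is restricted
+to matching latent rows: tokens are regrouped from `(b v) (h w) c` to `(b h) (v w) c`, so each pixel row only
+attends across views within the same row, which keeps the attention matrix much smaller than dense cross-view
+attention over all tokens.
+"""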
+from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import pdb +import random +import math + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
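+        num_views (`int`, *optional*, defaults to 1):
+            Number of views folded into the batch dimension. In this variant the views attend to each other row
+            by row (see `MVAttnProcessor` below); `sparse_mv_attention` is accepted for config compatibility but
+            is not used.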
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + cd_attention_last: bool=False, + cd_attention_mid: bool=False, + multiview_attention: bool=True, + sparse_mv_attention: bool = True, # not used + mvcd_attention: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + mvcd_attention=mvcd_attention + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. 
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
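+        num_views (`int`, *optional*, defaults to 1):
+            Number of views folded into the batch dimension for row-wise multi-view attention.
+        rowwise_attention (`bool`, *optional*, defaults to `True`):
+            Switch for row-wise multi-view attention (stored as `multiview_attention and rowwise_attention`).
+            The row-wise processors recover the latent height as `int(sqrt(sequence_length))`, so the flattened
+            latents are assumed to form a square grid.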
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + mvcd_attention: bool = False, + rowwise_attention: bool = True + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.mvcd_attention = mvcd_attention + self.rowwise_attention = multiview_attention and rowwise_attention + + # rowwise multiview attention + + print('INFO: using row wise attention...') + + self.attn1 = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + self.cd_attention_last = cd_attention_last + + if self.cd_attention_last: + # Joint task -Attn + self.attn_joint = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint.to_out[0].weight.data) + self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + + self.cd_attention_mid = cd_attention_mid + + if self.cd_attention_mid: + print("joint twice") + # Joint task -Attn + self.attn_joint_twice = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data) + self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + multiview_attention=self.multiview_attention, + mvcd_attention=self.mvcd_attention, + num_views=self.num_views, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # joint attention twice + if self.cd_attention_mid: + norm_hidden_states = ( + self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states) + ) + hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states + + # 2. 
Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + if self.cd_attention_last: + norm_hidden_states = ( + self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states) + ) + hidden_states = self.attn_joint(norm_hidden_states) + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
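+
+    Row-wise multi-view self-attention: queries, keys and values are regrouped so attention runs per latent row
+    across views. Illustrative shapes (placeholders, with 6 views and 32x32 latents):
+
+        (b * 6, 32 * 32, c) --"(b v) (h w) c -> (b h) (v w) c"--> (b * 32, 6 * 32, c)
+
+    Each row of each view attends to the matching row of all views; the output is rearranged back afterwards.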
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # multi-view self-attention + key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
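+
+    xformers (memory-efficient) counterpart of the row-wise `MVAttnProcessor`. With `mvcd_attention` enabled,
+    the keys/values of the two domains stacked along the batch axis are swapped and appended along the token
+    axis (cross-domain attention).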
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True, + mvcd_attention=False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + print('Warning: using group norm, pay attention to use it in row-wise attention') + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + if mvcd_attention: + # memory efficient, cross domain attention + key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) + key_cross = torch.concat([key_1, key_0], dim=0) + value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c + key = torch.cat([key, key_cross], dim=1) + value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c + + + query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64]) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + # print(hidden_states.shape) + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related 
computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/multiview/models/transformer_mv2d_self_rowwise.py b/multiview/models/transformer_mv2d_self_rowwise.py new file mode 100755 index 0000000000000000000000000000000000000000..3f728ad904bee17be8cec32905eb1aa665488f26 --- /dev/null +++ b/multiview/models/transformer_mv2d_self_rowwise.py @@ -0,0 +1,1042 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange +import pdb +import random +import math + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@dataclass +class TransformerMV2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class TransformerMV2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + num_views: int = 1, + cd_attention_mid: bool=False, + cd_attention_last: bool=False, + multiview_attention: bool=True, + sparse_mv_attention: bool = True, # not used + mvcd_attention: bool=False, + use_dino: bool=False + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. 
Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicMVTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + dino_feature: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + hw_ratio: Optional[torch.FloatTensor] = 1.5, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. 
+ + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. 
Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + hw_ratio=hw_ratio, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return TransformerMV2DModelOutput(sample=output) + + +@maybe_allow_in_graph +class BasicMVTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. 
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + mvcd_attention: bool = False, + rowwise_attention: bool = True, + use_dino: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.mvcd_attention = mvcd_attention + self.cd_attention_mid = cd_attention_mid + self.rowwise_attention = multiview_attention and rowwise_attention + + if mvcd_attention and (not cd_attention_mid): + # add cross domain attn to self attn + self.attn1 = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + else: + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention + ) + # 1.1 rowwise multiview attention + if self.rowwise_attention: + # print('INFO: using self+row_wise mv attention...') + self.norm_mv = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn_mv = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + nn.init.zeros_(self.attn_mv.to_out[0].weight.data) + else: + self.norm_mv = None + self.attn_mv = None + + # # 1.2 rowwise 
cross-domain attn + # if mvcd_attention: + # self.attn_joint = CustomJointAttention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # cross_attention_dim=cross_attention_dim if only_cross_attention else None, + # upcast_attention=upcast_attention, + # processor=JointAttnProcessor() + # ) + # nn.init.zeros_(self.attn_joint.to_out[0].weight.data) + # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # else: + # self.attn_joint = None + # self.norm_joint = None + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None, + hw_ratio: Optional[torch.FloatTensor] = 1.5, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. 
Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + # multiview_attention=self.multiview_attention, + # mvcd_attention=self.mvcd_attention, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # import pdb;pdb.set_trace() + # 1.1 row wise multiview attention + if self.rowwise_attention: + norm_hidden_states = ( + self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states) + ) + attn_output = self.attn_mv( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views, + multiview_attention=self.multiview_attention, + cd_attention_mid=self.cd_attention_mid, + hw_ratio=hw_ratio, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + cd_attention_mid=False + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + def transpose(tensor): + tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c + tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c + tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) + return tensor + + if cd_attention_mid: + key = transpose(key) + value = transpose(value) + query = transpose(query) + else: + key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + 
hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if cd_attention_mid: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) + hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c + hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c + hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) + else: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True, + cd_attention_mid=False, + hw_ratio=1.5 + ): + # print(num_views) + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + height = int(math.sqrt(sequence_length*hw_ratio)) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
+ _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + print('Warning: using group norm, pay attention to use it in row-wise attention') + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + # pdb.set_trace() + def transpose(tensor): + tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height) + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c + tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c + tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height) + return tensor + # print(mvcd_attention) + # import pdb;pdb.set_trace() + if cd_attention_mid: + key = transpose(key_raw) + value = transpose(value_raw) + query = transpose(query) + else: + key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) + query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320]) + + + query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64]) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if cd_attention_mid: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height) + hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c + hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c + hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height) + else: + hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + def transpose(tensor): + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c + tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c + return tensor + key = transpose(key) + value = transpose(value) + query = transpose(query) + # from icecream import ic + # ic(key.shape, value.shape, query.shape) + # import pdb;pdb.set_trace() + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2) + hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + def transpose(tensor): + tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c + tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c + return tensor + key = transpose(key) + value = transpose(value) + query = transpose(query) + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0) # 2bv hw c + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/multiview/models/unet_mv2d_blocks.py b/multiview/models/unet_mv2d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..9f0ee5afa33cc85246b045b2de2a70dff3220d6d --- /dev/null +++ b/multiview/models/unet_mv2d_blocks.py @@ -0,0 +1,980 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D + +from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D +from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif down_block_type == "CrossAttnDownBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") + return CrossAttnDownBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + 
use_dino=use_dino + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, 
+ resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + num_views=1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif up_block_type == "CrossAttnUpBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") + return CrossAttnUpBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + 
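+            # Remaining multiview-specific kwargs: `selfattn_block` selects which
+            # TransformerMV2DModel implementation the block imports ("custom",
+            # "rowwise" or "self_rowwise"), and `use_dino` threads optional DINO
+            # features into those transformers.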
selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + 
attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockMV2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + else: + raise NotImplementedError + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + hw_ratio = hidden_states.size(2) / hidden_states.size(3) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + 
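+                # `hw_ratio` and the optional `dino_feature` are the multiview-specific
+                # extras passed through to TransformerMV2DModel; with return_dict=False
+                # the transformer returns a tuple, so [0] below picks the sample tensor.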
encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + dino_feature=dino_feature, + return_dict=False, + hw_ratio=hw_ratio, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + else: + raise NotImplementedError + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: 
Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None + ): + hw_ratio = hidden_states.size(2) / hidden_states.size(3) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + dino_feature, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + hw_ratio, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + dino_feature=dino_feature, + hw_ratio=hw_ratio, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnDownBlockMV2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool=False, + use_dino: bool = False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if selfattn_block == "custom": + from .transformer_mv2d import TransformerMV2DModel + elif selfattn_block == "rowwise": + from .transformer_mv2d_rowwise import TransformerMV2DModel + elif selfattn_block == "self_rowwise": + from .transformer_mv2d_self_rowwise import TransformerMV2DModel + else: + raise NotImplementedError + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + 
pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + ) + else: + raise NotImplementedError + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + dino_feature: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + hw_ratio = hidden_states.size(2) / hidden_states.size(3) + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + dino_feature, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + hw_ratio, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + hw_ratio=hw_ratio, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + diff --git a/multiview/models/unet_mv2d_condition.py b/multiview/models/unet_mv2d_condition.py new file mode 100755 index 0000000000000000000000000000000000000000..0a8b13df7c0768d9d45c8b91a62f91c190b5d59f 
--- /dev/null +++ b/multiview/models/unet_mv2d_condition.py @@ -0,0 +1,1685 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_torch_version, + logging, +) +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.utils.hub_utils import HF_HUB_OFFLINE +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE + +from diffusers import __version__ +from .unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) +from einops import rearrange, repeat + +from diffusers import __version__ +from .unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
+ """ + + sample: torch.FloatTensor = None + + +class ResidualBlock(nn.Module): + def __init__(self, dim): + super(ResidualBlock, self).__init__() + self.linear1 = nn.Linear(dim, dim) + self.activation = nn.SiLU() + self.linear2 = nn.Linear(dim, dim) + + def forward(self, x): + identity = x + out = self.linear1(x) + out = self.activation(out) + out = self.linear2(out) + out += identity + out = self.activation(out) + return out + +class ResidualLiner(nn.Module): + def __init__(self, in_features, out_features, dim, act=None, num_block=1): + super(ResidualLiner, self).__init__() + self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU()) + + blocks = nn.ModuleList() + for _ in range(num_block): + blocks.append(ResidualBlock(dim)) + self.blocks = blocks + + self.linear_out = nn.Linear(dim, out_features) + self.act = act + + def forward(self, x): + out = self.linear_in(x) + for block in self.blocks: + out = block(out) + out = self.linear_out(out) + if self.act is not None: + out = self.act(out) + return out + +class BasicConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(BasicConvBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + self.act = nn.SiLU() + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + self.downsample = nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), + nn.GroupNorm(num_groups=8, num_channels=in_channels, affine=True) + ) + + def forward(self, x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.act(out) + out = self.conv2(out) + out = self.norm2(out) + out += self.downsample(identity) + out = self.act(out) + return out + +class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
+ time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] 
= None, + projection_camera_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + cd_attention_last: bool = False, + cd_attention_mid: bool = False, + multiview_attention: bool = True, + sparse_mv_attention: bool = False, + selfattn_block: str = "custom", + mvcd_attention: bool = False, + regress_elevation: bool = False, + regress_focal_length: bool = False, + num_regress_blocks: int = 4, + use_dino: bool = False, + addition_downsample: bool = False, + addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280), + ): + super().__init__() + + self.sample_size = sample_size + self.num_views = num_views + self.mvcd_attention = mvcd_attention + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 
+ ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
+ ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + 
resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + self.addition_downsample = addition_downsample + if self.addition_downsample: + inc = block_out_channels[-1] + self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.conv_block = nn.ModuleList() + self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1)) + for dim_ in addition_channels[1:-1]: + self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1)) + self.conv_block.append(BasicConvBlock(dim_, inc)) + self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False) + nn.init.zeros_(self.addition_conv_out.weight.data) + self.addition_act_out = nn.SiLU() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.regress_elevation = regress_elevation + self.regress_focal_length = regress_focal_length + if regress_elevation or regress_focal_length: + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) + + regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels + + if regress_elevation: + self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) + if regress_focal_length: + self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks) + ''' + self.regress_elevation = regress_elevation + self.regress_focal_length = regress_focal_length + if regress_elevation and (not regress_focal_length): + print("Regressing elevation") + cam_dim = 1 + elif regress_focal_length and (not regress_elevation): + print("Regressing focal length") + cam_dim = 6 + elif regress_elevation and regress_focal_length: + print("Regressing both elevation and focal length") + cam_dim = 7 + else: + cam_dim = 0 + assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim" + if regress_elevation or regress_focal_length: + self.elevation_regressor = nn.ModuleList([ + nn.Linear(block_out_channels[-1], 1280), + nn.SiLU(), + nn.Linear(1280, 1280), + nn.SiLU(), + nn.Linear(1280, cam_dim) + ]) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.focal_act = nn.Softmax(dim=-1) + self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim) + ''' + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + 
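+            # Channel bookkeeping for the decoder: `output_channel` comes from the
+            # reversed encoder channel list, `prev_output_channel` is what the previous
+            # up block produced, and `input_channel` matches the skip connections coming
+            # from the corresponding down block. Each up block gets one extra resnet
+            # (`reversed_layers_per_block[i] + 1`) so it can also consume the residual
+            # saved by the downsampler.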
prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + cd_attention_last=cd_attention_last, + cd_attention_mid=cd_attention_mid, + multiview_attention=multiview_attention, + sparse_mv_attention=sparse_mv_attention, + selfattn_block=selfattn_block, + mvcd_attention=mvcd_attention, + use_dino=use_dino + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
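+                A minimal usage sketch (assuming an already constructed `model`):
+                `model.set_attn_processor(AttnProcessor())` applies the default processor
+                to every attention layer, while passing a dict keyed like
+                `model.attn_processors` overrides individual layers.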
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + dino_feature: Optional[torch.Tensor] = None, + return_dict: bool = True, + vis_max_min: bool = False, + ) -> Union[UNetMV2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + record_max_min = {} + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
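+        # Illustrative note (added): e.g. with four up blocks, three of them register an upsampler, so the
+        # factor below is 2**3 = 8; a 64x64 latent needs no forcing (64 % 8 == 0), while a 100x100 latent
+        # triggers `forward_upsample_size` (100 % 8 == 4).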
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + emb_pre_act = emb + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, 
image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for i, downsample_block in enumerate(self.down_blocks): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + if self.addition_downsample: + global_sample = sample + global_sample = self.downsample(global_sample) + for layer in self.conv_block: + global_sample = layer(global_sample) + global_sample = self.addition_act_out(self.addition_conv_out(global_sample)) + global_sample = self.upsample(global_sample) + # 4. 
mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # 4.1 regress elevation and focal length + # # predict elevation -> embed -> projection -> add to time emb + if self.regress_elevation or self.regress_focal_length: + pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C) + if self.mvcd_attention: + pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0) + pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C) + pose_pred = [] + if self.regress_elevation: + ele_pred = self.elevation_regressor(pool_embeds) + ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views) + ele_pred = torch.mean(ele_pred, dim=1) + pose_pred.append(ele_pred) # b, c + + if self.regress_focal_length: + focal_pred = self.focal_regressor(pool_embeds) + focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views) + focal_pred = torch.mean(focal_pred, dim=1) + pose_pred.append(focal_pred) + pose_pred = torch.cat(pose_pred, dim=-1) + # 'e_de_da_sincos', (B, 2) + pose_embeds = torch.cat([ + torch.sin(pose_pred), + torch.cos(pose_pred) + ], dim=-1) + pose_embeds = self.camera_embedding(pose_embeds) + pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0) + if self.mvcd_attention: + pose_embeds = torch.cat([pose_embeds,] * 2, dim=0) + + emb = pose_embeds + emb_pre_act + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + if self.addition_downsample: + sample = sample + global_sample + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + dino_feature=dino_feature, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + if torch.isnan(sample).any() or torch.isinf(sample).any(): + print("NAN in sample, stop training.") + exit() + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + if not return_dict: + return (sample, pose_pred) + if self.regress_elevation or self.regress_focal_length: + return UNetMV2DConditionOutput(sample=sample), pose_pred + else: + return UNetMV2DConditionOutput(sample=sample) + + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_camera_embeddings_input_dim: int=2, + cd_attention_last: bool = False, num_regress_blocks: int = 4, + cd_attention_mid: bool = False, multiview_attention: bool = True, + sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False, + init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. + + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. 
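+
+        Note (added for clarity): in this multiview subclass the usual entry point is `from_pretrained_2d`
+        itself, which additionally takes the multiview-specific arguments listed in the signature above. An
+        illustrative call might look like the sketch below; the path and numeric values are placeholders, and
+        the enclosing class is referred to here as `UNetMV2DConditionModel`:
+
+        ```py
+        unet = UNetMV2DConditionModel.from_pretrained_2d(
+            "path/to/pretrained-unet",
+            camera_embedding_type="e_de_da_sincos",
+            num_views=6,
+            sample_size=32,
+        )
+        ```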
+ + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors: + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." 
+ ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['cd_attention_last'] = cd_attention_last + config['cd_attention_mid'] = cd_attention_mid + config['multiview_attention'] = multiview_attention + config['sparse_mv_attention'] = sparse_mv_attention + config['selfattn_block'] = selfattn_block + config['mvcd_attention'] = mvcd_attention + config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + + + config['regress_elevation'] = regress_elevation # true + config['regress_focal_length'] = regress_focal_length # true + config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length + config['use_dino'] = use_dino + config['num_regress_blocks'] = num_regress_blocks + config['addition_downsample'] = addition_downsample + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + import copy + state_dict_pretrain = load_state_dict(model_file, variant=variant) + state_dict = copy.deepcopy(state_dict_pretrain) + + if init_mvattn_with_selfattn: + for key in state_dict_pretrain: + if 'attn1' in key: + key_mv = key.replace('attn1', 'attn_mv') + state_dict[key_mv] = state_dict_pretrain[key] + if 'to_out.0.weight' in key: + nn.init.zeros_(state_dict[key_mv].data) + if 'transformer_blocks' in key and 'norm1' in key: # in case that initialize the norm layer in resnet block + key_mv = 
key.replace('norm1', 'norm_mv') + state_dict[key_mv] = state_dict_pretrain[key] + # del state_dict_pretrain + + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if zero_init_camera_projection: # true + params = [p for p in model.camera_embedding.parameters()] + torch.nn.init.zeros_(params[-1].data) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." 
+ ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/multiview/pipeline_multiclass.py b/multiview/pipeline_multiclass.py new file mode 100644 index 0000000000000000000000000000000000000000..de6829a38343585592bbc1e72383d97169bdfdda --- /dev/null +++ b/multiview/pipeline_multiclass.py @@ -0,0 +1,656 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union, Dict, Any +import PIL +import torch +import kornia +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTextModel +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.embeddings import get_timestep_embedding +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +import os +import torchvision.transforms.functional as TF +from einops import rearrange +logger = logging.get_logger(__name__) + + +def CLIP_preprocess(x): + dtype = x.dtype + # following openai's implementation + # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741 + # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608 + if isinstance(x, torch.Tensor): + if x.min() < -1.0 or x.max() > 1.0: + raise ValueError("Expected input tensor to have values in the range [-1, 1]") + x = kornia.geometry.resize(x.to(torch.float32), (224, 224), interpolation='bicubic', align_corners=True, antialias=False).to(dtype=dtype) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), + torch.Tensor([0.26862954, 0.26130258, 0.27577711])) + return x + + +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image to image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + feature_extractor ([`CLIPFeatureExtractor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by `noise_level` in `StableUnCLIPPipeline.__call__`. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 
+ scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + # image encoding components + feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + # regular denoising components + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + num_views: int = 4, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.num_views: int = num_views + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list + models = [ + self.image_encoder, + self.text_encoder, + self.unet, + self.vae, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. 
After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + if do_classifier_free_guidance: + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0) + + prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0) + + return prompt_embeds + + def _encode_image( + self, + # image_pil, + image, + device, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level: int=0, + class_targets: list=None, + generator: Optional[torch.Generator] = None + ): + dtype = next(self.image_encoder.parameters()).dtype + # ______________________________clip image embedding______________________________ + image_ = CLIP_preprocess(image) + image_embeds = self.image_encoder(image_).image_embeds + + image_embeds_ls = [] + + for class_target in class_targets: + image_embeds_ls.append(self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + class_target=class_target, + generator=generator, + ).repeat(num_images_per_prompt, 1)) + + if do_classifier_free_guidance: + for idx in range(len(image_embeds_ls)): + normal_image_embeds, color_image_embeds = torch.chunk(image_embeds_ls[idx], 2, dim=0) + negative_prompt_embeds = torch.zeros_like(normal_image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds_ls[idx] = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0) + + # _____________________________vae input latents__________________________________________________ + image_latents = self.vae.encode(image.to(self.vae.dtype)).latent_dist.mode() * self.vae.config.scaling_factor + # Note: repeat differently from official pipelines + image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1) + + if do_classifier_free_guidance: + normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0) + image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents, + torch.zeros_like(color_image_latents), color_image_latents], 0) + + return image_embeds_ls, image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = noise.clone() + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, noise + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + class_target: torch.Tensor, + noise: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways + 1. A noise schedule is applied directly to the embeddings + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. 
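+
+        In this multi-class variant, a tiled one-hot `class_target` vector is what gets concatenated to the
+        noised embeddings at the end, so the conditioning carries both the CLIP image embedding and the
+        requested level. An illustrative sketch of how `__call__` below builds such a target:
+
+        ```py
+        class_target = torch.tensor([0, 0, 0, 0])
+        class_target[level] = 1  # mark the requested level in a 4-slot one-hot
+        class_target = torch.repeat_interleave(class_target, 256).unsqueeze(0)  # shape (1, 1024)
+        ```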
+ """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + dtype = image_embeds.dtype + + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + image_embeds = image_embeds.to(dtype=dtype) + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, class_target.repeat(image_embeds.shape[0] // class_target.shape[0], 1)), 1) + + return image_embeds + + + @torch.no_grad() + def __call__( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + prompt_embeds: torch.FloatTensor = None, + dino_feature: torch.FloatTensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + image_embeds: Optional[torch.FloatTensor] = None, + return_elevation_focal: Optional[bool] = False, + gt_img_in: Optional[torch.FloatTensor] = None, + num_levels: Optional[int] = 3, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which + the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the + latents in the denoising process such as in the standard stable diffusion text guided image variation + process. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in + the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as + `latents`. 
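+            return_elevation_focal (`bool`, *optional*, defaults to `False`):
+                If `True`, the per-step elevation and focal-length predictions from the UNet's pose head are
+                split out of the classifier-free-guidance batch and accumulated during the denoising loop.
+            gt_img_in (`torch.FloatTensor`, *optional*):
+                If provided, used (scaled by the scheduler's `init_noise_sigma`) in place of freshly sampled
+                initial latents.
+            num_levels (`int`, *optional*, defaults to 3):
+                Number of conditioning levels to generate; with `num_levels=2` only levels 1 and 2 are run.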
+ + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level + ) + + # 2. Define call parameters + if isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + assert batch_size >= self.num_views and batch_size % self.num_views == 0 + elif isinstance(image, PIL.Image.Image): + image = [image]*self.num_views*2 + batch_size = self.num_views*2 + + if isinstance(prompt, str): + prompt = [prompt] * self.num_views * 2 + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale != 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + + # 4. Encoder input image + noise_level = torch.tensor([noise_level], device=device) + + class_targets = [] + for level in [0, 1, 2]: + class_target = torch.tensor([0, 0, 0, 0]).cuda() + class_target[level] = 1 + class_target = torch.repeat_interleave(class_target, 256).unsqueeze(0) + class_targets.append(class_target) + + image_embeds_ls, image_latents = self._encode_image( + image=image, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + class_targets=class_targets, + generator=generator, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.out_channels + if gt_img_in is not None: + latents = gt_img_in * self.scheduler.init_noise_sigma + else: + latents, noise = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + original_latents = latents.clone() + image_ls = [] + now_range = range(1, 3) if num_levels == 2 else range(num_levels) + for level in now_range: + latents = original_latents.clone() + eles, focals = [], [] + # 8. 
Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + if do_classifier_free_guidance: + normal_latents, color_latents = torch.chunk(latents, 2, dim=0) + latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0) + else: + latent_model_input = latents + + latent_model_input = torch.cat([ + latent_model_input, image_latents + ], dim=1) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + unet_out = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + dino_feature=dino_feature, + class_labels=image_embeds_ls[level], + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False) + + noise_pred = unet_out[0] + if return_elevation_focal: + uncond_pose, pose = torch.chunk(unet_out[1], 2, 0) + pose = uncond_pose + guidance_scale * (pose - uncond_pose) + ele = pose[:, 0].detach().cpu().numpy() # b + eles.append(ele) + focal = pose[:, 1].detach().cpu().numpy() + focals.append(focal) + + # perform guidance + if do_classifier_free_guidance: + normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0) + + noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + if not output_type == "latent": + if num_channels_latents == 8: + latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0) + with torch.no_grad(): + image = self.vae.decode((latents / self.vae.config.scaling_factor).to(self.vae.dtype), return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + image = ImagePipelineOutput(images=image) + image_ls.append(image) + + return image_ls diff --git a/refine/func.py b/refine/func.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf06edfd58d480df00bf0ee049b228f5921edcb --- /dev/null +++ b/refine/func.py @@ -0,0 +1,427 @@ +import torch +from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase +from pytorch3d.renderer import ( + RasterizationSettings, + TexturesVertex, + FoVPerspectiveCameras, + FoVOrthographicCameras, +) +from pytorch3d.structures import Meshes +from PIL import Image +from typing import List +from refine.render import _warmup +import pymeshlab as ml +from pymeshlab import Percentage +import nvdiffrast.torch as dr +import numpy as np + + +def _translation(x, y, z, device): + return torch.tensor([[1., 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]],device=device) #4,4 + +def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + """ + see https://blog.csdn.net/wodownload2/article/details/85069240/ + """ + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + p = torch.zeros([4,4],device=device) + p[0,0] = 2*n/(r-l) + p[0,2] = (r+l)/(r-l) + p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1) + p[1,2] = (t+b)/(t-b) + p[2,2] = -(f+n)/(f-n) + p[2,3] = -(2*f*n)/(f-n) + p[3,2] = -1 + return p #4,4 + +def _orthographic(r, device, l=None, t=None, 
b=None, n=1.0, f=50.0, flip_y=True): + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + o = torch.zeros([4,4],device=device) + o[0,0] = 2/(r-l) + o[0,3] = -(r+l)/(r-l) + o[1,1] = 2/(t-b) * (-1 if flip_y else 1) + o[1,3] = -(t+b)/(t-b) + o[2,2] = -2/(f-n) + o[2,3] = -(f+n)/(f-n) + o[3,3] = 1 + return o #4,4 + +def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): + if r is None: + r = 1/distance + A = az_count + P = pol_count + C = A * P + + phi = torch.arange(0,A) * (2*torch.pi/A) + phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone() + phi_rot[:,0,2,2] = phi.cos() + phi_rot[:,0,2,0] = -phi.sin() + phi_rot[:,0,0,2] = phi.sin() + phi_rot[:,0,0,0] = phi.cos() + + theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2 + theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone() + theta_rot[0,:,1,1] = theta.cos() + theta_rot[0,:,1,2] = -theta.sin() + theta_rot[0,:,2,1] = theta.sin() + theta_rot[0,:,2,2] = theta.cos() + + mv = torch.empty((C,4,4), device=device) + mv[:] = torch.eye(4, device=device) + mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3) + mv = _translation(0, 0, -distance, device) @ mv + + return mv, _projection(r,device) + + +def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): + mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device) + if r is None: + r = 1 + return mv, _orthographic(r,device) + + +def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'): + # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183 + R = world_to_cam[:3, :3].t()[None, ...] + T = world_to_cam[:3, 3][None, ...] 
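+    # The branch below builds either a perspective camera (driven by `fov_in_degrees`) or an orthographic
+    # camera; in the orthographic case `focal_length` is inverted and reused as the half-extent of the view
+    # volume (min_x/max_x/min_y/max_y), so a larger focal value corresponds to a tighter crop.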
+ if cam_type == 'fov': + camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True) + else: + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + + +def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1): + ret = [] + for azim in azim_list: + R, T = look_at_view_transform(dist, 0, azim) + w2c = torch.cat([R[0].T, T[0, :, None]], dim=1) + cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device) + ret.append(cameras) + return ret + + +def to_py3d_mesh(vertices, faces, normals=None): + from pytorch3d.structures import Meshes + from pytorch3d.renderer.mesh.textures import TexturesVertex + mesh = Meshes(verts=[vertices], faces=[faces], textures=None) + if normals is None: + normals = mesh.verts_normals_packed() + # set normals as vertext colors + mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5]) + return mesh + + +def from_py3d_mesh(mesh): + return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed() + + +class Pix2FacesRenderer: + def __init__(self, device="cuda"): + self._glctx = dr.RasterizeCudaContext(device=device) + self.device = device + _warmup(self._glctx, device) + + def transform_vertices(self, meshes: Meshes, cameras: CamerasBase): + vertices = cameras.transform_points_ndc(meshes.verts_padded()) + + perspective_correct = cameras.is_perspective() + znear = cameras.get_znear() + if isinstance(znear, torch.Tensor): + znear = znear.min().item() + z_clip = None if not perspective_correct or znear is None else znear / 2 + + if z_clip: + vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip + vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices) + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32) + return vertices + + def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512): + meshes = meshes.to(self.device) + cameras = cameras.to(self.device) + vertices = self.transform_vertices(meshes, cameras) + faces = meshes.faces_packed().to(torch.int32) + rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4 + pix_to_face = rast_out[..., -1].to(torch.int32) - 1 + return pix_to_face + +pix2faces_renderer = Pix2FacesRenderer() + +def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024): + # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face'] + pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution) + + unique_faces = torch.unique(pix_to_face.flatten()) + unique_faces = unique_faces[unique_faces != -1] + return unique_faces + + +def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object. + cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object. + pil_image (PIL.Image.Image): The input image. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. 
Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + debug (bool, optional): Whether to save debug images. Defaults to False. + + Returns: + dict: A dictionary containing the following keys: + - "new_texture" (TexturesVertex): The updated texture with interpolated colors. + - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected. + - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices. + """ + meshes = meshes.to(device) + cameras = cameras.to(device) + image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.] + unique_faces = get_visible_faces(meshes, cameras, resolution=resolution) + + # visible faces + faces_normals = meshes.faces_normals_packed()[unique_faces] + faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True) + world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0] + view_direction = world_points[1] - world_points[0] + view_direction = view_direction / view_direction.norm(dim=0, keepdim=True) + + # find invalid faces + cos_angles = (faces_normals * view_direction).sum(dim=1) + assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}" + selected_faces = unique_faces[cos_angles < -eps] + + # find verts + faces = meshes.faces_packed()[selected_faces] # [N, 3] + verts = torch.unique(faces.flatten()) # [N, 1] + verts_coordinates = meshes.verts_packed()[verts] # [N, 3] + + # compute color + pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points + valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1 dict: + """ + meshes: the mesh with vertex color to be completed. + valid_index: the index of the valid vertices, where valid means colors are fixed. 
[V, 1] + """ + valid_index = valid_index.to(meshes.device) + colors = meshes.textures.verts_features_packed() # [V, 3] + V = colors.shape[0] + + invalid_index = torch.ones_like(colors[:, 0]).bool() # [V] + invalid_index[valid_index] = False + invalid_index = torch.arange(V).to(meshes.device)[invalid_index] + + L = meshes.laplacian_packed() + E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device) + L = L + E + # import pdb; pdb.set_trace() + # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device) + # L = L + E + colored_count = torch.ones_like(colors[:, 0]) # [V] + colored_count[invalid_index] = 0 + L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V] + + total_colored = colored_count.sum() + coloring_round = 0 + stage = "uncolored" + from tqdm import tqdm + pbar = tqdm(miniters=100) + while stage == "uncolored" or coloring_round > 0: + new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3] + new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1] + colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index]) + colored_count[invalid_index] = (new_count[:, 0] > 0).float() + + new_total_colored = colored_count.sum() + if new_total_colored > total_colored: + total_colored = new_total_colored + coloring_round += 1 + else: + stage = "colored" + coloring_round -= 1 + pbar.update(1) + if coloring_round > 10000: + print("coloring_round > 10000, break") + break + assert not torch.isnan(colors).any() + meshes.textures = TexturesVertex(verts_features=[colors]) + return meshes + + +def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase]=None, camera_focal: float = 2 / 1.35, weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth", distract_mask=None) -> Meshes: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh. + image_list (PIL.Image.Image): List of images. + cameras_list (list): List of cameras. + camera_focal (float, optional): The focal length of the camera, if cameras_list is not passed. Defaults to 2 / 1.35. + weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1. + complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False. + + Returns: + Meshes: the colored mesh + """ + # 1. 
preprocess inputs + if image_list is None: + raise ValueError("image_list is None") + if cameras_list is None: + raise ValueError("cameras_list is None") + if weights is None: + raise ValueError("weights is None, and can not be guessed from image_list") + + # 2. run projection + meshes = meshes.clone().to(device) + if weights is None: + weights = [1. for _ in range(len(cameras_list))] + assert len(cameras_list) == len(image_list) == len(weights) + original_color = meshes.textures.verts_features_packed() + assert not torch.isnan(original_color).any() + texture_counts = torch.zeros_like(original_color[..., :1]) + texture_values = torch.zeros_like(original_color) + max_texture_counts = torch.zeros_like(original_color[..., :1]) + max_texture_values = torch.zeros_like(original_color) + for camera, image, weight in zip(cameras_list, image_list, weights): + ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha) + if reweight_with_cosangle == "linear": + weight = (ret['cos_angles'].abs() * weight)[:, None] + elif reweight_with_cosangle == "square": + weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None] + if use_alpha: + weight = weight * ret['valid_alpha'] + assert weight.min() > -0.0001 + texture_counts[ret['valid_verts']] += weight + texture_values[ret['valid_verts']] += ret['valid_colors'] * weight + max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']]) + max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight) + + # Method2 + texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values) + if below_confidence_strategy == "smooth": + texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values) + elif below_confidence_strategy == "original": + texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values) + else: + raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported") + assert not torch.isnan(texture_values).any() + meshes.textures = TexturesVertex(verts_features=[texture_values]) + + if distract_mask is not None: + import cv2 + pil_distract_mask = (distract_mask * 255).astype(np.uint8) + pil_distract_mask = cv2.erode(pil_distract_mask, np.ones((3, 3), np.uint8), iterations=2) + pil_distract_mask = Image.fromarray(pil_distract_mask) + ret = project_color(meshes, cameras_list[0], pil_distract_mask, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha) + distract_valid_mask = ret['valid_colors'][:, 0] > 0.5 + distract_invalid_index = ret['valid_verts'][~distract_valid_mask] + + # invalid index's neighbors also should included + L = meshes.laplacian_packed() + # Convert invalid indices to a boolean mask + distract_invalid_mask = torch.zeros(meshes.verts_packed().shape[0:1], dtype=torch.bool, device=device) + distract_invalid_mask[distract_invalid_index] = True + + # Find neighbors: multiply Laplacian with invalid_mask and check non-zero values + # Extract COO format (L.indices() gives [2, N] shape: row, col; L.values() gives values) + row_indices, col_indices = L.coalesce().indices() + invalid_rows = distract_invalid_mask[row_indices] + neighbor_indices = col_indices[invalid_rows] + + # Combine original invalids with their 
neighbors + combined_invalid_mask = distract_invalid_mask.clone() + combined_invalid_mask[neighbor_indices] = True + + # repeat + invalid_rows = combined_invalid_mask[row_indices] + neighbor_indices = col_indices[invalid_rows] + combined_invalid_mask[neighbor_indices] = True + + # Apply to texture counts and values + texture_counts[combined_invalid_mask] = 0 + texture_values[combined_invalid_mask] = 0 + + + if complete_unseen: + meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold]) + ret_mesh = meshes.detach() + del meshes + return ret_mesh + + +def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes: + verts = torch.from_numpy(mesh.vertex_matrix()).float() + faces = torch.from_numpy(mesh.face_matrix()).long() + colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float() + textures = TexturesVertex(verts_features=[colors]) + return Meshes(verts=[verts], faces=[faces], textures=textures) + + +def to_pyml_mesh(vertices,faces): + m1 = ml.Mesh( + vertex_matrix=vertices.cpu().float().numpy().astype(np.float64), + face_matrix=faces.cpu().long().numpy().astype(np.int32), + ) + return m1 + + +def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25): + ms = ml.MeshSet() + ms.add_mesh(pyml_mesh, "cube_mesh") + + if apply_smooth: + ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False) + if apply_sub_divide: # 5s, slow + ms.apply_filter("meshing_repair_non_manifold_vertices") + ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces') + ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=Percentage(sub_divide_threshold)) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) diff --git a/refine/mesh_refine.py b/refine/mesh_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..7a55add24f8199436e76e436765ed06a36b68dda --- /dev/null +++ b/refine/mesh_refine.py @@ -0,0 +1,319 @@ +import torch +import numpy as np +import trimesh +from PIL import Image +from typing import List +from tqdm import tqdm +from sklearn.neighbors import KDTree + +from refine.func import from_py3d_mesh, get_cameras_list, make_star_cameras_orthographic, multiview_color_projection, simple_clean_mesh, to_py3d_mesh, to_pyml_mesh +from refine.opt import MeshOptimizer +from refine.render import NormalsRenderer, calc_vertex_normals + +import pytorch3d +from pytorch3d.structures import Meshes + +def remove_color(arr): + if arr.shape[-1] == 4: + arr = arr[..., :3] + # calc diffs + base = arr[0, 0] + diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1) + alpha = (diffs <= 80) + + arr[alpha] = 255 + alpha = ~alpha + arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1) + return arr + + +def simple_remove(imgs): + """Only works for normal""" + if not isinstance(imgs, list): + imgs = [imgs] + single_input = True + else: + single_input = False + rets = [] + for img in imgs: + arr = np.array(img) + arr = remove_color(arr) + rets.append(Image.fromarray(arr.astype(np.uint8))) + if single_input: + return rets[0] + return rets + + +def erode_alpha(img_list): + out_img_list = [] + for idx, img in enumerate(img_list): + arr = np.array(img) + alpha = (arr[:, :, 3] > 127).astype(np.uint8) + # erode 1px + import cv2 + alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1) + alpha = (alpha * 255).astype(np.uint8) + img = 
Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1)) + out_img_list.append(img) + return out_img_list + + +def merge_small_faces(mesh, thres=1e-5): + area_faces = mesh.area_faces + small_faces = area_faces < thres + + vertices = mesh.vertices + faces = mesh.faces + + new_vertices = vertices.tolist() + vertex_mapping = {} + + for face_idx in np.where(small_faces)[0]: + face = faces[face_idx] + v1, v2, v3 = face + center = np.mean(vertices[face], axis=0) + + new_vertex_idx = len(new_vertices) + new_vertices.append(center) + + vertex_mapping[v1] = new_vertex_idx + vertex_mapping[v2] = new_vertex_idx + vertex_mapping[v3] = new_vertex_idx + + for k,v in vertex_mapping.items(): + faces[faces == k] = v + + faces = faces[~small_faces] + + new_mesh = trimesh.Trimesh(vertices=new_vertices, faces=faces, postprocess=False) + new_mesh.remove_unreferenced_vertices() + new_mesh.remove_degenerate_faces() + new_mesh.remove_duplicate_faces() + + return new_mesh + + +def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"): + # Convert the background color to a PyTorch tensor + new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device) + + # Convert all images to PyTorch tensors and process them + imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255 + img_nps = imgs[..., :3] + alpha_nps = imgs[..., 3] + ori_bkgds = img_nps[:, :1, :1] + + # Avoid divide by zero and calculate the original image + alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1) + ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1) + ori_img_nps = torch.clamp(ori_img_nps, 0, 1) + img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd) + + rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1) + return rgba_img_np + + +def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, fixed_v=None, fixed_f=None, lr=0.03, start_edge_len=0.15, end_edge_len=0.005, + decay=0.995, loss_expansion_weight=0.1, gain=0.1, remesh_interval=1, remesh_start=0, distract_mask=None, distract_bbox=None): + vertices, faces = vertices.cuda(), faces.cuda() + assert len(pils) == 6 + mv, proj = make_star_cameras_orthographic(8, 1, r=1.2) + mv = mv[[4, 3, 2, 0, 6, 5]] + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) + + # init from coarse mesh + opt = MeshOptimizer(vertices, faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len), lr=lr, + remesh_interval=remesh_interval, remesh_start=remesh_start) + + _vertices = opt.vertices + _faces = opt.faces + + if fixed_v is not None and fixed_f is not None: + kdtree = KDTree(fixed_v.cpu().numpy()) + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + faces = torch.cat([_faces, fixed_f + len(_vertices)], dim=0) if fixed_f is not None else _faces + vertices = torch.cat([_vertices, fixed_v], dim=0) if fixed_v is not None else _vertices + + opt.zero_grad() + opt._lr *= decay + normals = calc_vertex_normals(vertices,faces) + + normals[:, 0] *= -1 + normals[:, 2] *= -1 + + images = renderer.render(vertices,normals,faces) + loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean() + + t_mask = images[..., -1] > 0.5 + loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean() + loss_alpha_target_mask_l2 = 
(images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight + + if distract_mask is not None: + hair_visible_normals = normals + hair_visible_normals[len(_vertices):] = -1. + _images = renderer.render(vertices,hair_visible_normals,faces) + loss_distract = (_images[0][distract_mask] - target_images[0][distract_mask]).pow(2).mean() + + target_outside = target_images[0][..., :3].clone() + target_outside[~distract_mask] = 0. + + loss_outside_distract = (_images[0][..., :3][~distract_mask] - target_outside[..., :3][~distract_mask]).pow(2).mean() + + loss = loss + loss_distract * 1. + loss_outside_distract * 10. + + if fixed_v is not None and fixed_f is not None: + _, idx = kdtree.query(_vertices.detach().cpu().numpy(), k=1) + idx = idx.squeeze() + anchors = fixed_v[idx].detach() + + normals_fixed = calc_vertex_normals(fixed_v, fixed_f) + loss_anchor = (torch.clamp(((anchors - _vertices) * normals_fixed[idx]).sum(-1), min=-0)+0).pow(3) + loss_anchor_dist_mask = (anchors - _vertices).norm(dim=-1) < 0.05 + loss_anchor = loss_anchor[loss_anchor_dist_mask].mean() + + loss = loss + loss_anchor * 100. + + # out of box + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + loss = loss + loss_oob + + loss.backward() + opt.step() + + if i % remesh_interval == 0 and i >= remesh_start: + _vertices,_faces = opt.remesh(poisson=False) + + vertices, faces = opt._vertices.detach(), opt._faces.detach() + + return vertices, faces + + +def run_mesh_refine(vertices, faces, pils: List[Image.Image], fixed_v=None, fixed_f=None, steps=100, start_edge_len=0.02, end_edge_len=0.005, + decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True, remesh_interval=20): + poission_steps = [] + + assert len(pils) == 6 + mv, proj = make_star_cameras_orthographic(8, 1, r=1.2) + mv = mv[[4, 3, 2, 0, 6, 5]] + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s + + # init from coarse mesh + opt = MeshOptimizer(vertices, faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02) + + _vertices = opt.vertices + _faces = opt.faces + alpha_init = None + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + faces = torch.cat([_faces, fixed_f + len(_vertices)], dim=0) if fixed_f is not None else _faces + vertices = torch.cat([_vertices, fixed_v], dim=0) if fixed_v is not None else _vertices + + opt.zero_grad() + opt._lr *= decay + normals = calc_vertex_normals(vertices,faces) + images = renderer.render(vertices,normals,faces) + if alpha_init is None: + alpha_init = images.detach() + + if i < update_warmup or i % update_normal_interval == 0: + with torch.no_grad(): + py3d_mesh = to_py3d_mesh(vertices, faces, normals) + cameras = get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], device=vertices.device, focal=1/1.2) + _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2,0.8,0.8,2,0.8,0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear')) + target_normal = target_normal * 2 - 1 + target_normal = torch.nn.functional.normalize(target_normal, dim=-1) + + target_normal[:, 0] *= -1 + target_normal[:, 2] *= -1 + + debug_images = renderer.render(vertices,target_normal,faces) + + d_mask = images[..., -1] > 0.5 + 
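+        # Despite the "debug" naming, `debug_images` is a render of the mesh using the target normals fused from
+        # the multi-view predictions, and `loss_debug_l2` below pulls the currently rendered normals (`images`)
+        # towards that target, evaluated only where the current render is opaque (`d_mask`).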
loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean() + + loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_debug_l2 + loss_alpha_target_mask_l2 + + # out of box + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + loss = loss + loss_oob + + loss.backward() + opt.step() + + if i % remesh_interval == 0: + _vertices,_faces = opt.remesh(poisson=(i in poission_steps)) + + vertices, faces = opt._vertices.detach(), opt._faces.detach() + + if process_outputs: + vertices = vertices / 2 * 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + + return vertices, faces + + +def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None, + distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False): + rm_normals = simple_remove(normal_ls) + + # transfer the alpha channel of rm_normals to img_list + for idx, img in enumerate(rm_normals): + rgb_ls[idx] = Image.fromarray(np.concatenate([np.array(rgb_ls[idx])[..., :3], np.array(img)[:, :, 3:4]], axis=-1)) + assert np.mean(np.array(rgb_ls[0])[..., 3]) < 250 + + rgb_ls = erode_alpha(rgb_ls) + + stage1_lr = 0.08 if fixed_v is None else 0.01 + stage1_remesh_interval = 1 if fixed_v is None else 30 + + if no_decompose: + stage1_lr = 0.03 + stage1_remesh_interval = 30 + + vertices, faces = reconstruct_stage1(rm_normals, steps=200, vertices=mesh_v, faces=mesh_f, fixed_v=fixed_v, fixed_f=fixed_f, + lr=stage1_lr, remesh_interval=stage1_remesh_interval, start_edge_len=0.02, + end_edge_len=0.005, gain=0.05, loss_expansion_weight=expansion_weight, + distract_mask=distract_mask, distract_bbox=distract_bbox) + + vertices, faces = run_mesh_refine(vertices, faces, rm_normals, fixed_v=fixed_v, fixed_f=fixed_f, steps=100, start_edge_len=0.005, end_edge_len=0.0002, + decay=0.99, update_normal_interval=20, update_warmup=5, process_inputs=False, process_outputs=False, remesh_interval=1) + meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=2, apply_sub_divide=False, sub_divide_threshold=0.25).to("cuda") + # subdivide meshes + simp_vertices, simp_faces = meshes.verts_packed(), meshes.faces_packed() + vertices, faces = simp_vertices.detach().cpu().numpy(), simp_faces.detach().cpu().numpy() + + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + mesh = merge_small_faces(mesh, thres=thres) + new_mesh = mesh.split(only_watertight=False) + + new_mesh = [ j for j in new_mesh if len(j.vertices) >= 200 ] + mesh = trimesh.Scene(new_mesh).dump(concatenate=True) + vertices, faces = mesh.vertices.astype('float32'), mesh.faces + + vertices, faces = trimesh.remesh.subdivide(vertices, faces) + origin_len_v, origin_len_f = len(vertices), len(faces) + # concatenate fixed_v and fixed_f + if fixed_v is not None and fixed_f is not None: + vertices, faces = np.concatenate([vertices, fixed_v.detach().cpu().numpy()], axis=0), np.concatenate([faces, fixed_f.detach().cpu().numpy() + len(vertices)], axis=0) + vertices, faces = torch.tensor(vertices, device='cuda'), torch.tensor(faces, device='cuda') + # reconstruct meshes + meshes = Meshes(verts=[vertices], faces=[faces], textures=pytorch3d.renderer.mesh.textures.TexturesVertex([torch.zeros_like(vertices).float()])) + new_meshes = multiview_color_projection(meshes, rgb_ls, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([180, 225, 270, 0, 90, 135], "cuda", focal=1/1.2), 
+                                            weights=[2.0, 0.5, 0.0, 1.0, 0.0, 0.5] if distract_mask is None else [2.0, 0.0, 0.5, 1.0, 0.5, 0.0], distract_mask=distract_mask)
+    # exclude fixed_v and fixed_f
+    if fixed_v is not None and fixed_f is not None:
+        new_meshes = Meshes(verts=[new_meshes.verts_packed()[:origin_len_v]], faces=[new_meshes.faces_packed()[:origin_len_f]],
+                            textures=pytorch3d.renderer.mesh.textures.TexturesVertex([new_meshes.textures.verts_features_packed()[:origin_len_v]]))
+    return new_meshes, simp_vertices, simp_faces
diff --git a/refine/opt.py b/refine/opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..06f7160b6fac281dc6de40acc871403cd2b10636
--- /dev/null
+++ b/refine/opt.py
@@ -0,0 +1,192 @@
+# modified from https://github.com/Profactor/continuous-remeshing
+import time
+import torch
+import torch_scatter
+from typing import Tuple
+from .remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
+
+@torch.no_grad()
+def remesh(
+        vertices_etc:torch.Tensor, #V,D
+        faces:torch.Tensor, #F,3 long
+        min_edgelen:torch.Tensor, #V
+        max_edgelen:torch.Tensor, #V
+        flip:bool,
+        max_vertices=1e7
+        ):
+
+    # dummies
+    vertices_etc,faces = prepend_dummies(vertices_etc,faces)
+    vertices = vertices_etc[:,:3] #V,3
+    nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
+    min_edgelen = torch.concat((nan_tensor,min_edgelen))
+    max_edgelen = torch.concat((nan_tensor,max_edgelen))
+
+    # collapse
+    edges,face_to_edge = calc_edges(faces) #E,2 F,3
+    edge_length = calc_edge_length(vertices,edges) #E
+    face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
+    vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
+    face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
+    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
+    priority = face_collapse.float() + shortness
+    vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
+
+    # split
+    if vertices.shape[0]<max_vertices:
+        edges,face_to_edge = calc_edges(faces) #E,2 F,3
+        edge_length = calc_edge_length(vertices,edges) #E
+        splits = edge_length > max_edgelen[edges].mean(dim=-1)
+        vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
+
+    vertices_etc,faces = pack(vertices_etc,faces)
+    vertices = vertices_etc[:,:3]
+
+    if flip:
+        edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
+        flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
+
+    return remove_dummies(vertices_etc,faces)
+
+def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
+    """lerp with adam's bias correction"""
+    c_prev = 1-weight**(step-1)
+    c = 1-weight**step
+    a_weight = weight*c_prev/c
+    b_weight = (1-weight)/c
+    a.mul_(a_weight).add_(b, alpha=b_weight)
+
+
+class MeshOptimizer:
+    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
+
+    def __init__(self,
+            vertices:torch.Tensor, #V,3
+            faces:torch.Tensor, #F,3
+            lr=0.03, #learning rate
+            betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
+            gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max.
smoothing) + nu_ref=0.3, #reference velocity for edge length controller + edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length + edge_len_tol=.5, #edge length tolerance for split and collapse + gain=.2, #gain value for edge length controller + laplacian_weight=.02, #for laplacian smoothing/regularization + ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0]) + grad_lim=10., #gradients are clipped to m1.abs()*grad_lim + remesh_interval=1, #larger intervals are faster but with worse mesh quality + remesh_start=0, + local_edgelen=True, #set to False to use a global scalar reference edge length instead + ): + self._vertices = vertices + self._faces = faces + self._lr = lr + self._betas = betas + self._gammas = gammas + self._nu_ref = nu_ref + self._edge_len_lims = edge_len_lims + self._edge_len_tol = edge_len_tol + self._gain = gain + self._laplacian_weight = laplacian_weight + self._ramp = ramp + self._grad_lim = grad_lim + self._remesh_interval = remesh_interval + self._remesh_start = remesh_start + self._local_edgelen = local_edgelen + self._step = 0 + + V = self._vertices.shape[0] + # prepare continuous tensor for all vertex-based data + self._vertices_etc = torch.zeros([V,9],device=vertices.device) + self._split_vertices_etc() + self.vertices.copy_(vertices) #initialize vertices + self._vertices.requires_grad_() + self._ref_len.fill_(edge_len_lims[1]) + + @property + def vertices(self): + return self._vertices + + @property + def faces(self): + return self._faces + + def _split_vertices_etc(self): + self._vertices = self._vertices_etc[:,:3] + self._m2 = self._vertices_etc[:,3] + self._nu = self._vertices_etc[:,4] + self._m1 = self._vertices_etc[:,5:8] + self._ref_len = self._vertices_etc[:,8] + + with_gammas = any(g!=0 for g in self._gammas) + self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3] + + def zero_grad(self): + self._vertices.grad = None + + @torch.no_grad() + def step(self): + + eps = 1e-8 + + self._step += 1 + + # spatial smoothing + edges,_ = calc_edges(self._faces) #E,2 + E = edges.shape[0] + edge_smooth = self._smooth[edges] #E,2,S + neighbor_smooth = torch.zeros_like(self._smooth) #V,S + torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth) + + #apply optional smoothing of m1,m2,nu + if self._gammas[0]: + self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0]) + if self._gammas[1]: + self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1]) + if self._gammas[2]: + self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2]) + + #add laplace smoothing to gradients + laplace = self._vertices - neighbor_smooth[:,:3] + grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight) + + #gradient clipping + if self._step>1: + grad_lim = self._m1.abs().mul_(self._grad_lim) + grad.clamp_(min=-grad_lim,max=grad_lim) + + # moment updates + lerp_unbiased(self._m1, grad, self._betas[0], self._step) + lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step) + + velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3 + speed = velocity.norm(dim=-1) #V + + if self._betas[2]: + lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V + else: + self._nu.copy_(speed) #V + + # update vertices + ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp) + self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr) + + # update target edge length + if self._step % 
self._remesh_interval == 0 and self._step >= self._remesh_start: + if self._local_edgelen: + len_change = (1 + (self._nu - self._nu_ref) * self._gain) + else: + len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain) + self._ref_len *= len_change + self._ref_len.clamp_(*self._edge_len_lims) + + def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]: + min_edge_len = self._ref_len * (1 - self._edge_len_tol) + max_edge_len = self._ref_len * (1 + self._edge_len_tol) + + self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e7) + + self._split_vertices_etc() + self._vertices.requires_grad_() + + return self._vertices, self._faces \ No newline at end of file diff --git a/refine/remesh.py b/refine/remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..4aca51e300a77e6528fd5eaa234a0dbc3be44529 --- /dev/null +++ b/refine/remesh.py @@ -0,0 +1,361 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import torch +import torch.nn.functional as tfunc +import torch_scatter +from typing import Tuple + +def prepend_dummies( + vertices:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + )->Tuple[torch.Tensor,torch.Tensor]: + """prepend dummy elements to vertices and faces to enable "masked" scatter operations""" + V,D = vertices.shape + vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0) + faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0) + return vertices,faces + +def remove_dummies( + vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced + faces:torch.Tensor, #F,3 long - first face all zeros + )->Tuple[torch.Tensor,torch.Tensor]: + """remove dummy elements added with prepend_dummies()""" + return vertices[1:],faces[1:]-1 + + +def calc_edges( + faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros + with_edge_to_face: bool = False + ) -> Tuple[torch.Tensor, ...]: + """ + returns Tuple of + - edges E,2 long, 0 for unused, lower vertex index first + - face_to_edge F,3 long + - (optional) edge_to_face shape=E,[left,right],[face,side] + + o-<-----e1 e0,e1...edge, e0-o + """ + + F = faces.shape[0] + + # make full edges, lower vertex index first + face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2 + full_edges = face_edges.reshape(F*3,2) + sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 + + # make unique edges + edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3) + E = edges.shape[0] + face_to_edge = full_to_unique.reshape(F,3) #F,3 + + if not with_edge_to_face: + return edges, face_to_edge + + is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3 + edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2 + scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2 + edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2 + edge_to_face[0] = 0 + return edges, face_to_edge, edge_to_face + +def calc_edge_length( + vertices:torch.Tensor, #V,3 first may be dummy + edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused + )->torch.Tensor: #E + + full_vertices = vertices[edges] #E,2,3 + a,b = full_vertices.unbind(dim=1) #E,3 + return torch.norm(a-b,p=2,dim=-1) + +def calc_face_normals( + vertices:torch.Tensor, 
#V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + normalize:bool=False, + )->torch.Tensor: #F,3 + """ + n + | + c0 corners ordered counterclockwise when + / \ looking onto surface (in neg normal direction) + c1---c2 + """ + full_vertices = vertices[faces] #F,C=3,3 + v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 + face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 + if normalize: + face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) + return face_normals #F,3 + +def calc_vertex_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + face_normals:torch.Tensor=None, #F,3, not normalized + )->torch.Tensor: #F,3 + + F = faces.shape[0] + + if face_normals is None: + face_normals = calc_face_normals(vertices,faces) + + vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 + vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) + vertex_normals = vertex_normals.sum(dim=1) #V,3 + return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) + +def calc_face_ref_normals( + faces:torch.Tensor, #F,3 long, 0 for unused + vertex_normals:torch.Tensor, #V,3 first unused + normalize:bool=False, + )->torch.Tensor: #F,3 + """calculate reference normals for face flip detection""" + full_normals = vertex_normals[faces] #F,C=3,3 + ref_normals = full_normals.sum(dim=1) #F,3 + if normalize: + ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1) + return ref_normals + +def pack( + vertices:torch.Tensor, #V,3 first unused and nan + faces:torch.Tensor, #F,3 long, 0 for unused + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused + """removes unused elements in vertices and faces""" + V = vertices.shape[0] + + # remove unused faces + used_faces = faces[:,0]!=0 + used_faces[0] = True + faces = faces[used_faces] #sync + + # remove unused vertices + used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device) + used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') + used_vertices = used_vertices.any(dim=1) + used_vertices[0] = True + vertices = vertices[used_vertices] #sync + + # update used faces + ind = torch.zeros(V,dtype=torch.long,device=vertices.device) + V1 = used_vertices.sum() + ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync + faces = ind[faces] + + return vertices,faces + +def split_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + splits, #E bool + pack_faces:bool=True, + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + # c2 c2 c...corners = faces + # . . . . s...side_vert, 0 means no split + # . . .N2 . S...shrunk_face + # . . . . Ni...new_faces + # s2 s1 s2|c2...s1|c1 + # . . . . . + # . . . S . . + # . . . . N1 . 
+ # c0...(s0=0)....c1 s0|c0...........c1 + # + # pseudo-code: + # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2] + # split = side_vert!=0 example:[False,True,True] + # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0] + # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0] + # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1] + + V = vertices.shape[0] + F = faces.shape[0] + S = splits.sum().item() #sync + + if S==0: + return vertices,faces + + edge_vert = torch.zeros_like(splits, dtype=torch.long) #E + edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync + side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split + split_edges = edges[splits] #S sync + + #vertices + split_vertices = vertices[split_edges].mean(dim=1) #S,3 + vertices = torch.concat((vertices,split_vertices),dim=0) + + #faces + side_split = side_vert!=0 #F,3 + shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split + new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3 + faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3 + if pack_faces: + mask = faces[:,0]!=0 + mask[0] = True + faces = faces[mask] #F',3 sync + + return vertices,faces + +def collapse_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + priorities:torch.Tensor, #E float + stable:bool=False, #only for unit testing + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + V = vertices.shape[0] + + # check spacing + _,order = priorities.sort(stable=stable) #E + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + edge_rank = rank #E + for i in range(3): + torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank) + edge_rank,_ = vert_rank[edges].max(dim=-1) #E + candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2 + + # check connectivity + vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + vert_connections[candidates[:,0]] = 1 #start + edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start + vert_connections[candidates] = 0 #clear start and end + edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start + collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end + + # mean vertices + vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) + + # update faces + dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V + dest[collapses[:,1]] = dest[collapses[:,0]] + faces = dest[faces] #F,3 + c0,c1,c2 = faces.unbind(dim=-1) + collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2) + faces[collapsed] = 0 + + return vertices,faces + +def calc_face_collapses( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + 
edge_length:torch.Tensor, #E + face_normals:torch.Tensor, #F,3 + vertex_normals:torch.Tensor, #V,3 first unused + min_edge_length:torch.Tensor=None, #V + area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio + shortest_probability = 0.8 + )->torch.Tensor: #E edges to collapse + + E = edges.shape[0] + F = faces.shape[0] + + # face flips + ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3 + face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F + + # small faces + if min_edge_length is not None: + min_face_length = min_edge_length[faces].mean(dim=-1) #F + min_area = min_face_length**2 * area_ratio #F + face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F + face_collapses[0] = False + + # faces to edges + face_length = edge_length[face_to_edge] #F,3 + + if shortest_probability<1: + #select shortest edge with shortest_probability chance + randlim = round(2/(1-shortest_probability)) + rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face + sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3 + local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None]) + else: + local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face + + edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index + edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device) + edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) + + return edge_collapses.bool() + +def flip_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused + edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first + edge_to_face:torch.Tensor, #E,[left,right],[face,side] + with_border:bool=True, #handle border edges (D=4 instead of D=6) + with_normal_check:bool=True, #check face normal flips + stable:bool=False, #only for unit testing + ): + V = vertices.shape[0] + E = edges.shape[0] + device=vertices.device + vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long + vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add') + neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner + neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2 + edge_is_inside = neighbors.all(dim=-1) #E + + if with_border: + # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices + # need to use float for masks in order to use scatter(reduce='multiply') + vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float + src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float + vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply') + vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long + vertex_degree -= 2 * vertex_is_inside #V long + + neighbor_degrees = vertex_degree[neighbors] #E,LR=2 + edge_degrees = vertex_degree[edges] #E,2 + # + # loss = Sum_over_affected_vertices((new_degree-6)**2) + # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2) + # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2) + # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree)) + # + loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E + candidates = torch.logical_and(loss_change<0, edge_is_inside) #E + loss_change = 
loss_change[candidates] #E' + if loss_change.shape[0]==0: + return + + edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4 + _,order = loss_change.sort(descending=True, stable=stable) #E' + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4 + torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank) + vertex_rank,_ = vertex_rank.max(dim=-1) #V + neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E' + flip = rank==neighborhood_rank #E' + + if with_normal_check: + # cl-<-----e1 e0,e1...edge, e0-cr + v = vertices[edges_neighbors] #E",4,3 + v = v - v[:,0:1] #make relative to e0 + e1 = v[:,1] + cl = v[:,2] + cr = v[:,3] + n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors + flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face + flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face + + flip_edges_neighbors = edges_neighbors[flip] #E",4 + flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2 + flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3 + faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3)) \ No newline at end of file diff --git a/refine/render.py b/refine/render.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae06c7a74d48cd9511011739180c67039b403fa --- /dev/null +++ b/refine/render.py @@ -0,0 +1,203 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import nvdiffrast.torch as dr +import torch +from typing import Tuple +import torch.nn.functional as tfunc + + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +glctx = dr.RasterizeCudaContext(device="cuda") + +class NormalsRenderer: + + _glctx:dr.RasterizeCudaContext = None + + def __init__( + self, + mv: torch.Tensor, #C,4,4 + proj: torch.Tensor, #C,4,4 + image_size: Tuple[int,int], + mvp = None, + device=None, + ): + if mvp is None: + self._mvp = proj @ mv #C,4,4 + else: + self._mvp = mvp + self._image_size = image_size + self._glctx = glctx + _warmup(self._glctx, device) + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + + V = vertices.shape[0] + faces = faces.type(torch.int32) + vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4 + vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 + + rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 + vert_col = (normals+1)/2 #V,3 + col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3 + alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1 + col = torch.concat((col,alpha),dim=-1) #C,H,W,4 + col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4 + return col #C,H,W,4 + + + +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh.shader import ShaderBase +from pytorch3d.renderer 
import ( + RasterizationSettings, + MeshRendererWithFragments, + TexturesVertex, + MeshRasterizer, + BlendParams, + FoVOrthographicCameras, + look_at_view_transform, + hard_rgb_blend, +) + +class VertexColorShader(ShaderBase): + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + blend_params = kwargs.get("blend_params", self.blend_params) + texels = meshes.sample_textures(fragments) + return hard_rgb_blend(texels, fragments, blend_params) + +def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"): + if len(mesh) != len(cameras): + if len(cameras) % len(mesh) == 0: + mesh = mesh.extend(len(cameras)) + else: + raise NotImplementedError() + + # render requires everything in float16 or float32 + input_dtype = dtype + blend_params = BlendParams(1e-4, 1e-4, bkgd) + + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=True, + bin_size=None, + max_faces_per_bin=None, + ) + + # Create a renderer by composing a rasterizer and a shader + # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used) + renderer = MeshRendererWithFragments( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ), + shader=VertexColorShader( + device=device, + cameras=cameras, + blend_params=blend_params + ) + ) + + # render RGB and depth, get mask + with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type): + images, _ = renderer(mesh) + return images # BHW4 + +class Pytorch3DNormalsRenderer: # 100 times slower!!! + def __init__(self, cameras, image_size, device): + self.cameras = cameras.to(device) + self._image_size = image_size + self.device = device + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device) + return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device) + +def save_tensor_to_img(tensor, save_dir): + from PIL import Image + import numpy as np + for idx, img in enumerate(tensor): + img = img[..., :3].cpu().numpy() + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + img.save(save_dir + f"{idx}.png") + +if __name__ == "__main__": + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d + cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) + mv,proj = make_star_cameras_orthographic(4, 1) + resolution = 1024 + renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda") + renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda") + vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32) + normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32) + faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long) + + import time + t0 = time.time() + r1 = renderer1.render(vertices, normals, faces) + print("time r1:", 
time.time() - t0) + + t0 = time.time() + r2 = renderer2.render(vertices, normals, faces) + print("time r2:", time.time() - t0) + + for i in range(4): + print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean()) + + +def calc_face_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + normalize:bool=False, + )->torch.Tensor: #F,3 + """ + n + | + c0 corners ordered counterclockwise when + / \ looking onto surface (in neg normal direction) + c1---c2 + """ + full_vertices = vertices[faces] #F,C=3,3 + v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 + face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 + if normalize: + face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) + return face_normals #F,3 + + +def calc_vertex_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + face_normals:torch.Tensor=None, #F,3, not normalized + )->torch.Tensor: #F,3 + + F = faces.shape[0] + + if face_normals is None: + face_normals = calc_face_normals(vertices,faces) + + vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 + vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) + vertex_normals = vertex_normals.sum(dim=1) #V,3 + return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7763e9be2d4546561b17f0e36a58406c9e2b31b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +omegaconf +rm_anime_bg +diffusers==0.27.2 +transformers==4.42.4 +einops +huggingface_hub==0.25.0 +opencv-python +accelerate +matplotlib +kornia +imageio +imageio-ffmpeg +xatlas +trimesh +rembg +onnxruntime +scikit-learn +pygltflib +pymeshlab==2022.2.post3 +pytorch_lightning +git+https://github.com/NVlabs/nvdiffrast +git+https://github.com/facebookresearch/pytorch3d.git +git+https://github.com/Baijiong-Lin/LoRA-Torch +git+https://github.com/facebookresearch/segment-anything.git diff --git a/slrm/__init__.py b/slrm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/slrm/cameras.npy b/slrm/cameras.npy new file mode 100644 index 0000000000000000000000000000000000000000..dcdc0d56cc53908349469d32e3174b6c0a1dcb48 --- /dev/null +++ b/slrm/cameras.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:050ce22f48bf06d9458bff03bcb21f53f5130ef5a2158e7affcb72d36b66aa98 +size 320 diff --git a/slrm/models/decoder/__init__.py b/slrm/models/decoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/slrm/models/decoder/transformer.py b/slrm/models/decoder/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..a2d57acdc1fd0e697a90b389f221bb1cca7766d4 --- /dev/null +++ b/slrm/models/decoder/transformer.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import loratorch as lora + + +class BasicTransformerBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + """ + # use attention from torch.nn.MultiHeadAttention + # Block contains a cross-attention layer, a self-attention layer, and a MLP + def __init__( + self, + inner_dim: int, + cond_dim: int, + num_heads: int, + eps: float, + attn_drop: float = 0., + attn_bias: bool = False, + mlp_ratio: float = 4., + mlp_drop: float = 0., + lora_rank: int = 0, + ): + super().__init__() + + self.norm1 = nn.LayerNorm(inner_dim) + self.cross_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm2 = nn.LayerNorm(inner_dim) + self.self_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm3 = nn.LayerNorm(inner_dim) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class TriplaneTransformer(nn.Module): + """ + Transformer with condition that generates a triplane representation. + + Reference: + Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 + """ + def __init__( + self, + inner_dim: int, + image_feat_dim: int, + triplane_low_res: int, + triplane_high_res: int, + triplane_dim: int, + num_layers: int, + num_heads: int, + eps: float = 1e-6, + lora_rank: int = 0, + ): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + # modules + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. 
/ inner_dim) ** 0.5) + self.layers = nn.ModuleList([ + BasicTransformerBlock( + inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps, lora_rank=lora_rank) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0) + + def forward(self, image_feats): + # image_feats: [N, L_cond, D_cond] + + N = image_feats.shape[0] + H = W = self.triplane_low_res + L = 3 * H * W + + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + for layer in self.layers: + x = layer(x, image_feats) + x = self.norm(x) + + # separate each plane and apply deconv + x = x.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.deconv(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + + return x diff --git a/slrm/models/encoder/__init__.py b/slrm/models/encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/slrm/models/encoder/dino.py b/slrm/models/encoder/dino.py new file mode 100755 index 0000000000000000000000000000000000000000..684444cab2a13979bcd5688069e9f7729d4ca784 --- /dev/null +++ b/slrm/models/encoder/dino.py @@ -0,0 +1,550 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViT model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, +) +from transformers import PreTrainedModel, ViTConfig +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer + + +class ViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
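+        The patch position embeddings are reshaped to their original 2D grid and bicubically
+        resized to the grid implied by the new input height and width; the [CLS] position
+        embedding is left unchanged.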
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
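+    For example, with a 224x224 input and a 16x16 patch size this produces (224/16)**2 = 196
+    patch embeddings (the actual sizes depend on the model configuration).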
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class ViTSelfAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
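+        # Overall this computes dropout(softmax(Q @ K^T / sqrt(d_head))) @ V for each head.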
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Module): + """ + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + 
hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class ViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) + ) + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + self_attention_outputs = self.attention( + modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTEncoder(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + adaln_input, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not 
return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +class ViTModel(ViTPreTrainedModel): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + adaln_input: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
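+            When provided (and the model was built with `use_mask_token=True`), the embeddings
+            of masked patches are replaced by the learned mask token before the position
+            embeddings are added.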
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + adaln_input=adaln_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTPooler(nn.Module): + def __init__(self, config: ViTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output \ No newline at end of file diff --git a/slrm/models/encoder/dino_wrapper.py b/slrm/models/encoder/dino_wrapper.py new file mode 100755 index 0000000000000000000000000000000000000000..e84fd51e7dfcfd1a969b763f5a49aeb7f608e6f9 --- /dev/null +++ b/slrm/models/encoder/dino_wrapper.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
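+# A minimal usage sketch for the DinoWrapper defined below; the model id and tensor shapes
+# are illustrative assumptions rather than values taken from this repository's configs:
+#
+#   wrapper = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True).cuda()
+#   images  = torch.rand(1, 6, 3, 224, 224).cuda()  # [B, N, C, H, W], RGB in [0, 1]
+#   cameras = torch.rand(1, 6, 16).cuda()           # one flattened 4x4 pose per view
+#   feats   = wrapper(images, cameras)              # [(B*N), num_tokens, hidden_size]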
+ + +import torch.nn as nn +from transformers import ViTImageProcessor +from einops import rearrange, repeat +from .dino import ViTModel + + +class DinoWrapper(nn.Module): + """ + Dino v1 wrapper using huggingface transformer implementation. + """ + def __init__(self, model_name: str, freeze: bool = True): + super().__init__() + self.model, self.processor = self._build_dino(model_name) + self.camera_embedder = nn.Sequential( + nn.Linear(16, self.model.config.hidden_size, bias=True), + nn.SiLU(), + nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) + ) + if freeze: + self._freeze() + + def forward(self, image, camera): + # image: [B, N, C, H, W] + # camera: [B, N, D] + # RGB image with [0,1] scale and properly sized + if image.ndim == 5: + image = rearrange(image, 'b n c h w -> (b n) c h w') + dtype = image.dtype + inputs = self.processor( + images=image.float(), + return_tensors="pt", + do_rescale=False, + do_resize=False, + ).to(self.model.device).to(dtype) + # embed camera + N = camera.shape[1] + camera_embeddings = self.camera_embedder(camera) + camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') + embeddings = camera_embeddings + # This resampling of positional embedding uses bicubic interpolation + outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) + last_hidden_states = outputs.last_hidden_state + return last_hidden_states + + def _freeze(self): + print(f"======== Freezing DinoWrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + try: + model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) + processor = ViTImageProcessor.from_pretrained(model_name) + return model, processor + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") + import time + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/slrm/models/geometry/__init__.py b/slrm/models/geometry/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..89e9a6c2fffe82a55693885dae78c1a630924389 --- /dev/null +++ b/slrm/models/geometry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/slrm/models/geometry/camera/__init__.py b/slrm/models/geometry/camera/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..c5c7082e47c65a08e25489b3c3fd010d07ad9758 --- /dev/null +++ b/slrm/models/geometry/camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from torch import nn + + +class Camera(nn.Module): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/slrm/models/geometry/camera/perspective_camera.py b/slrm/models/geometry/camera/perspective_camera.py new file mode 100755 index 0000000000000000000000000000000000000000..dc84d91ec006f894a5a7c06e84d5a02a63483138 --- /dev/null +++ b/slrm/models/geometry/camera/perspective_camera.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from . import Camera +import numpy as np + + +def projection(x=0.1, n=1.0, f=50.0, near_plane=None): + if near_plane is None: + near_plane = n + return np.array( + [[n / x, 0, 0, 0], + [0, n / -x, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0]]).astype(np.float32) + + +def ortho_projection(n=1.0, f=50.0): + return np.array( + [[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -2 / (f - n), -(f + n) / (f - n)], + [0, 0, 0, 1]]).astype(np.float32) + + +class PerspectiveCamera(Camera): + def __init__(self, fovy=49.0, device='cuda'): + super(PerspectiveCamera, self).__init__() + self.device = device + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + + def project(self, points_bxnx4): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + return out + + +class OrthogonalCamera(Camera): + def __init__(self, device='cuda'): + super(OrthogonalCamera, self).__init__() + self.device = device + self.proj_mtx = torch.from_numpy(ortho_projection(f=1000.0, n=0.1)).to(self.device).unsqueeze(dim=0) + + def project(self, points_bxnx4, ortho_scales_bx1): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + # print(ortho_scales_bx1) + out[:, :, 0] = out[:, :, 0] / ortho_scales_bx1[..., None] * 2. + out[:, :, 1] = out[:, :, 1] / ortho_scales_bx1[..., None] * 2. + return out diff --git a/slrm/models/geometry/render/__init__.py b/slrm/models/geometry/render/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..483cfabbf395853f1ca3e67b856d5f17b9889d1b --- /dev/null +++ b/slrm/models/geometry/render/__init__.py @@ -0,0 +1,8 @@ +import torch + +class Renderer(): + def __init__(self): + pass + + def forward(self): + pass \ No newline at end of file diff --git a/slrm/models/geometry/render/neural_render.py b/slrm/models/geometry/render/neural_render.py new file mode 100755 index 0000000000000000000000000000000000000000..bcf3e03a61d79f1562acf0300ef7f12587de33f4 --- /dev/null +++ b/slrm/models/geometry/render/neural_render.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . import Renderer +from ..camera.perspective_camera import PerspectiveCamera, OrthogonalCamera + +_FG_LUT = None + + +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate( + attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + +def xfm_points(points, matrix, use_python=True): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def compute_vertex_normal(v_pos, t_pos_idx): + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + +class NeuralRender(Renderer): + def __init__(self, device='cuda', camera_model=None): + super(NeuralRender, self).__init__() + self.device = device + self.ctx = dr.RasterizeCudaContext(device=device) + self.projection_mtx = None + self.camera = camera_model + + def render_mesh( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + camera_mv_bx4x4, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False, + dtype=None + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=mesh_v_pos_bxnx3.dtype, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4.to(mesh_v_pos_bxnx3) + + if isinstance(self.camera, PerspectiveCamera): + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + elif isinstance(self.camera, OrthogonalCamera): + ortho_scale_in = mtx_in[..., -1] + mtx_in = mtx_in[..., :-1].reshape(-1, 4, 4) + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) + v_pos_clip = self.camera.project(v_pos, ortho_scale_in) + + v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return 
the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3).to(dtype) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + normal = F.normalize(normal, dim=-1) + normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.to(normal.dtype)) # black background + + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/slrm/models/geometry/rep_3d/__init__.py b/slrm/models/geometry/rep_3d/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a3d5628a8433298477d1963f92578d47106b4a0f --- /dev/null +++ b/slrm/models/geometry/rep_3d/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np + + +class Geometry(): + def __init__(self): + pass + + def forward(self): + pass diff --git a/slrm/models/geometry/rep_3d/dmtet.py b/slrm/models/geometry/rep_3d/dmtet.py new file mode 100755 index 0000000000000000000000000000000000000000..b6a709380abac0bbf66fd1c8582485f3982223e4 --- /dev/null +++ b/slrm/models/geometry/rep_3d/dmtet.py @@ -0,0 +1,504 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . 
import Geometry +from .dmtet_utils import get_center_boundary_index +import torch.nn.functional as F + + +############################################################################### +# DMTet utility functions +############################################################################### +def create_mt_variable(device): + triangle_table = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + + num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) + base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) + return triangle_table, num_triangles_table, base_tet_edges, v_id + + +def sort_edges(edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + +############################################################################### +# marching tetrahedrons (differentiable) +############################################################################### + +def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + return verts, faces + + +def create_tetmesh_variables(device='cuda'): + tet_table = torch.tensor( + [[-1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1], + [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], + [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], + [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], + [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], + [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], + [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], + [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], + [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], + [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], + [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) + num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) + return tet_table, num_tets_table + + +def marching_tets_tetmesh( + pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, + return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + if not return_tet_mesh: + return verts, faces + occupied_verts = ori_v[occ_n] + mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") + tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) + + idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 + tet_verts = torch.cat([verts, occupied_verts], 0) + num_tets = num_tets_table[tetindex] + + tets = torch.cat( + ( + torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( + -1, + 4), + torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( + -1, + 4), + ), dim=0) + # add fully occupied tets + fully_occupied = occ_fx4.sum(-1) == 4 + 
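+        # A tet is "fully occupied" when all four of its SDF values are positive (all of it
+        # lies on the occupied side of the surface); its original vertices were appended after
+        # `verts` in `tet_verts`, hence the index offset by `verts.shape[0]`.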
tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] + tets = torch.cat([tets, tet_fully_occupied]) + + return verts, faces, tet_verts, tets + + +############################################################################### +# Compact tet grid +############################################################################### + +def compact_tets(pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + # Find surface tets + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets + + valid_vtx = tet_fx4[valid_tets].reshape(-1) + unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) + new_pos = pos_nx3[unique_vtx] + new_sdf = sdf_n[unique_vtx] + new_tets = idx_map.reshape(-1, 4) + return new_pos, new_sdf, new_tets + + +############################################################################### +# Subdivide volume +############################################################################### + +def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + device = tet_pos_bxnx3.device + # get new verts + tet_fx4 = tet_bxfx4[0] + edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] + all_edges = tet_fx4[:, edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + idx_map = idx_map + tet_pos_bxnx3.shape[1] + all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) + mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( + all_values.shape[0], -1, 2, + all_values.shape[-1]).mean(2) + new_v = torch.cat([all_values, mid_points_pos], 1) + new_v, new_sdf = new_v[..., :3], new_v[..., 3] + + # get new tets + + idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] + idx_ab = idx_map[0::6] + idx_ac = idx_map[1::6] + idx_ad = idx_map[2::6] + idx_bc = idx_map[3::6] + idx_bd = idx_map[4::6] + idx_cd = idx_map[5::6] + + tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) + tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) + tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) + tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) + tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) + tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) + tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) + tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) + + tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) + tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) + tet = tet_np.long().to(device) + + return new_v, tet, new_sdf + + +############################################################################### +# Adjacency +############################################################################### +def tet_to_tet_adj_sparse(tet_tx4): + # include self connection!!!!!!!!!!!!!!!!!!! 
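+    # Builds a row-normalized sparse adjacency matrix over tetrahedra: two tets are adjacent
+    # when they share a triangular face, and each tet is also connected to itself, so a
+    # sparse matmul with this matrix averages a per-tet quantity over its face neighbors
+    # (used by `shrink_grid` below to dilate the surface-tet mask by one ring).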
+ with torch.no_grad(): + t = tet_tx4.shape[0] + device = tet_tx4.device + idx_array = torch.LongTensor( + [0, 1, 2, + 1, 0, 3, + 2, 3, 0, + 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) + + # get all faces + all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( + -1, + 3) # (tx4, 3) + all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) + # sort and group + all_faces_sorted, _ = torch.sort(all_faces, dim=1) + + all_faces_unique, inverse_indices, counts = torch.unique( + all_faces_sorted, dim=0, return_counts=True, + return_inverse=True) + tet_face_fx3 = all_faces_unique[counts == 2] + counts = counts[inverse_indices] # tx4 + valid = (counts == 2) + + group = inverse_indices[valid] + # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) + _, indices = torch.sort(group) + all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] + tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) + + tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) + adj_self = torch.arange(t, device=tet_tx4.device) + adj_self = torch.stack([adj_self, adj_self], -1) + tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) + + tet_adj_idx = torch.unique(tet_adj_idx, dim=0) + values = torch.ones( + tet_adj_idx.shape[0], device=tet_tx4.device).float() + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + + # normalization + neighbor_num = 1.0 / torch.sparse.sum( + adj_sparse, dim=1).to_dense() + values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + return adj_sparse + + +############################################################################### +# Compact grid +############################################################################### + +def get_tet_bxfx4x3(bxnxz, bxfx4): + n_batch, z = bxnxz.shape[0], bxnxz.shape[2] + gather_input = bxnxz.unsqueeze(2).expand( + n_batch, bxnxz.shape[1], 4, z) + gather_index = bxfx4.unsqueeze(-1).expand( + n_batch, bxfx4.shape[1], 4, z).long() + tet_bxfx4xz = torch.gather( + input=gather_input, dim=1, index=gather_index) + + return tet_bxfx4xz + + +def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + with torch.no_grad(): + assert tet_pos_bxnx3.shape[0] == 1 + + occ = grid_sdf[0] > 0 + occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) + mask = (occ_sum > 0) & (occ_sum < 4) + + # build connectivity graph + adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) + mask = mask.float().unsqueeze(-1) + + # Include a one ring of neighbors + for i in range(1): + mask = torch.sparse.mm(adj_matrix, mask) + mask = mask.squeeze(-1) > 0 + + mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) + new_tet_bxfx4 = tet_bxfx4[:, mask].long() + selected_verts_idx = torch.unique(new_tet_bxfx4) + new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] + mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) + new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) + new_grid_sdf = grid_sdf[:, selected_verts_idx] + return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf + + +############################################################################### +# Regularizer 
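+# The two losses below penalize SDF sign flips along tet-grid edges: for every edge whose
+# endpoint values s_i, s_j disagree in sign, each endpoint is pushed toward the sign of the
+# other with a symmetric binary cross-entropy,
+#   L = BCE(sigmoid(s_i), 1[s_j > 0]) + BCE(sigmoid(s_j), 1[s_i > 0]),
+# which discourages spurious, isolated sign crossings.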
+############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], + (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], + (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +############################################################################### +# Geometry interface +############################################################################### +class DMTetGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(DMTetGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + tets = np.load('data/tets/%d_compress.npz' % (grid_res)) + self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) + # Make sure the tet is zero-centered and length is equal to 1 + length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] + length = length.max() + mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 + self.verts = (self.verts - mid.unsqueeze(dim=0)) / length + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + self.indices = torch.from_numpy(tets['tets']).long().to(self.device) + self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) + self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) + # Parameters for regularization computation + edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + all_edges = self.indices[:, edges].reshape(-1, 2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces = marching_tets( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces + + def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces, tet_verts, tets = marching_tets_tetmesh( + v_deformed_nx3, sdf_n, 
indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, + num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces, tet_verts, tets + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/slrm/models/geometry/rep_3d/dmtet_utils.py b/slrm/models/geometry/rep_3d/dmtet_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..8d466a9e78c49d947c115707693aa18d759885ad --- /dev/null +++ b/slrm/models/geometry/rep_3d/dmtet_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch + + +def get_center_boundary_index(verts): + length_ = torch.sum(verts ** 2, dim=-1) + center_idx = torch.argmin(length_) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/slrm/models/geometry/rep_3d/extract_texture_map.py b/slrm/models/geometry/rep_3d/extract_texture_map.py new file mode 100755 index 0000000000000000000000000000000000000000..a5d62bb5a6c5cdf632fb504db3d2dfa99a3abbd3 --- /dev/null +++ b/slrm/models/geometry/rep_3d/extract_texture_map.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import numpy as np +import nvdiffrast.torch as dr + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/slrm/models/geometry/rep_3d/flexicubes.py b/slrm/models/geometry/rep_3d/flexicubes.py new file mode 100755 index 0000000000000000000000000000000000000000..8e612430ea422cfe915b45136bd34fd7843bd691 --- /dev/null +++ b/slrm/models/geometry/rep_3d/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. 
+ check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. 
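Editorial note (not part of the patch): a minimal usage sketch of the FlexiCubes API documented above, assuming CUDA is available and the module is importable from the path this diff creates. It extracts a sphere mesh from an analytic SDF; shapes follow the docstring of __call__ below (s_n is negative inside the surface).

import torch
from slrm.models.geometry.rep_3d.flexicubes import FlexiCubes  # module added by this diff

fc = FlexiCubes(device="cuda")
res = 32
x_nx3, cube_fx8 = fc.construct_voxel_grid(res)   # grid vertices in [-0.5, 0.5]^3 and cube corner indices
s_n = x_nx3.norm(dim=-1) - 0.35                  # signed distance to a sphere, negative inside
vertices, faces, L_dev = fc(x_nx3, s_n, cube_fx8, res)   # triangle mesh plus per-dual-vertex regularizer
print(vertices.shape, faces.shape, L_dev.shape)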
+ """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. 
This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. 
_Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).to(dist) + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). 
+ # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. 
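Editorial note (not part of the patch): the helper documented here computes the standard linear zero-crossing. For two endpoints (x0, s0) and (x1, s1) whose scalar values have opposite signs, the algebra below reduces to (s1*x0 - s0*x1) / (s1 - s0); a scalar sanity check of that identity, written independently of the class:

import torch

def linear_zero_crossing(x0, s0, x1, s1):
    # same algebra as _linear_interp above, spelled out for two scalar endpoints
    return (s1 * x0 - s0 * x1) / (s1 - s0)

# s rises from -1 at x = 0.0 to +3 at x = 1.0, so the zero crossing sits at x = 0.25
assert torch.isclose(
    linear_zero_crossing(torch.tensor(0.0), torch.tensor(-1.0), torch.tensor(1.0), torch.tensor(3.0)),
    torch.tensor(0.25),
)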
+ """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 
1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device, dtype=beta_fx12.dtype) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. 
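Editorial note (not part of the patch): the branch that follows picks, per quadrilateral, which diagonal to triangulate along. When no grad_func is supplied and the module is not in training mode, the decision reduces to comparing the gamma products across the two diagonals. A standalone sketch of that rule, assuming a (Q, 4) tensor of dual-vertex indices and a (V,) tensor of per-vertex gamma weights (names are illustrative):

import torch

def split_quads_by_gamma(quad_vd_idx, vd_gamma):
    # quad_vd_idx: (Q, 4) dual-vertex indices per quad; vd_gamma: (V,) weights
    g = vd_gamma[quad_vd_idx]                                    # (Q, 4)
    use_diag_02 = (g[:, 0] * g[:, 2] > g[:, 1] * g[:, 3]).unsqueeze(1)
    split_1 = quad_vd_idx[:, [0, 1, 2, 0, 2, 3]]                 # triangles (0,1,2) and (0,2,3)
    split_2 = quad_vd_idx[:, [0, 1, 3, 3, 1, 2]]                 # triangles (0,1,3) and (3,1,2)
    return torch.where(use_diag_02, split_1, split_2).reshape(-1, 3)

# one quad over dual vertices 0..3; the larger product sits on diagonal (0, 2)
print(split_quads_by_gamma(torch.tensor([[0, 1, 2, 3]]), torch.tensor([1.0, 0.2, 1.0, 0.2])))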
+ with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. 
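Editorial note (not part of the patch): callers reach this interior tetrahedralization through the output_tetmesh flag of __call__; a hedged sketch reusing the sphere setup from the earlier note:

import torch
from slrm.models.geometry.rep_3d.flexicubes import FlexiCubes  # module added by this diff

fc = FlexiCubes(device="cuda")
x_nx3, cube_fx8 = fc.construct_voxel_grid(16)
s_n = x_nx3.norm(dim=-1) - 0.35                          # sphere SDF, negative inside
verts, tets, _ = fc(x_nx3, s_n, cube_fx8, 16, output_tetmesh=True)
print(verts.shape, tets.shape)                           # tets holds (T, 4) indices into verts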
+ """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. 
+ case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/slrm/models/geometry/rep_3d/flexicubes_geometry.py b/slrm/models/geometry/rep_3d/flexicubes_geometry.py new file mode 100755 index 0000000000000000000000000000000000000000..e2028360d434d5544178bd123d396bdf55920e97 --- /dev/null +++ b/slrm/models/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .flexicubes import FlexiCubes # replace later +from .dmtet import sdf_reg_loss_batch +import torch.nn.functional as F + +def get_center_boundary_index(grid_res_xy, grid_res_z, device): + v = torch.zeros((grid_res_xy + 1, grid_res_xy + 1, grid_res_z + 1), dtype=torch.bool, device=device) + v[grid_res_xy // 2 + 1, grid_res_xy // 2 + 1, grid_res_z // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res_xy // 2 + 1, grid_res_xy // 2 + 1, grid_res_z // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] 
= True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(Geometry): + def __init__( + self, grid_res_xy=64, grid_res_z=64, + scale_xy=2.0, scale_z=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res_xy = grid_res_xy + self.grid_res_z = grid_res_z + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid([grid_res_xy, grid_res_xy, grid_res_z]) + # if isinstance(scale, list): + # self.verts[:, 0] = self.verts[:, 0] * scale[0] + # self.verts[:, 1] = self.verts[:, 1] * scale[1] + # self.verts[:, 2] = self.verts[:, 2] * scale[1] + # else: + # self.verts = self.verts * scale + self.verts[:, 0] = self.verts[:, 0] * scale_xy + self.verts[:, 1] = self.verts[:, 1] * scale_xy + self.verts[:, 2] = self.verts[:, 2] * scale_z + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res_xy, self.grid_res_z, device) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, [self.grid_res_xy, self.grid_res_xy, self.grid_res_z], + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + return verts, faces, v_reg_loss + + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False, dtype=None): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask, + dtype=dtype + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + return_value['normal'] = normal + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate 
all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/slrm/models/geometry/rep_3d/tables.py b/slrm/models/geometry/rep_3d/tables.py new file mode 100755 index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca --- /dev/null +++ b/slrm/models/geometry/rep_3d/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 
-1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, 
-1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], 
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, 
-1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, 
-1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], 
+[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 
2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/slrm/models/lrm.py b/slrm/models/lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0fc1b80766fdecb024c5d5abb4ff2328c7188f --- /dev/null +++ b/slrm/models/lrm.py @@ -0,0 +1,238 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import mcubes +import nvdiffrast.torch as dr +import loratorch as lora +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer import TriplaneSynthesizer +from ..utils.mesh_util import xatlas_uvmap + + +class NeRFSLRM(nn.Module): + """ + Full model of the large reconstruction model. + """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + is_ortho: bool = False, + lora_rank: int = 0, + ): + super().__init__() + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + lora_rank=lora_rank, + ) + + if lora_rank > 0: + lora.mark_only_lora_as_trainable(self.transformer) + self.transformer.pos_embed.requires_grad = True + self.transformer.deconv.requires_grad = True + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + is_ortho=is_ortho, + ) + + if lora_rank > 0: + self.freeze_modules(encoder=True, transformer=False, + synthesizer=False) + + def freeze_modules(self, encoder=False, transformer=False, + synthesizer=False): + """ + Freeze specified modules + """ + if encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + if transformer: + for param in self.transformer.parameters(): + param.requires_grad = False + if synthesizer: + for param in self.synthesizer.parameters(): + param.requires_grad = False + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # transformer generating planes + planes = self.transformer(image_feats) + + return planes + + def forward_synthesizer(self, planes, render_cameras, render_size: int): + render_results = self.synthesizer( + planes, + render_cameras, + render_size, + ) + return render_results + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + + # render target views + render_results = self.synthesizer(planes, render_cameras, render_size) + + return { + 'planes': planes, + **render_results, + } + + def 
get_texture_prediction(self, planes, tex_pos, hard_mask=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat = torch.utils.checkpoint.checkpoint( + self.synthesizer.forward_points, + planes, + tex_pos, + use_reentrant=False, + )['rgb'] + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def extract_mesh( + self, + planes: torch.Tensor, + mesh_resolution: int = 256, + mesh_threshold: int = 10.0, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from triplane nerf. Only support batch_size 1. 
+ :param planes: triplane features + :param mesh_resolution: marching cubes resolution + :param mesh_threshold: iso-surface threshold + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texture map + ''' + assert planes.shape[0] == 1 + device = planes.device + + grid_out = self.synthesizer.forward_grid( + planes=planes, + grid_size=mesh_resolution, + ) + + vertices, faces = mcubes.marching_cubes( + grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + + if not use_texture_map: + # query vertex colors + vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) + vertices_colors = self.synthesizer.forward_points( + planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices, faces, vertices_colors + + # use x-atlas to get uv mapping for the mesh + vertices = torch.tensor(vertices, dtype=torch.float32, device=device) + faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) + + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float() + + # query the texture field to get the RGB color for texture map + tex_feat = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map diff --git a/slrm/models/lrm_mesh.py b/slrm/models/lrm_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..71a7afe6474f4badb066951c3dd16b59b3f568d2 --- /dev/null +++ b/slrm/models/lrm_mesh.py @@ -0,0 +1,612 @@ +# Copyright (c) 2023, Tencent Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import nvdiffrast.torch as dr +import loratorch as lora +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer_mesh import TriplaneSynthesizer +from .geometry.camera.perspective_camera import PerspectiveCamera, OrthogonalCamera +from .geometry.render.neural_render import NeuralRender +from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry +from ..utils.mesh_util import xatlas_uvmap + + +class MeshSLRM(nn.Module): + """ + Full model of the large reconstruction model. 
+ """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + grid_res_xy: int = 128, + grid_res_z: int = 128, + grid_scale_xy: float = 2.0, + grid_scale_z: float = 2.0, + is_ortho: bool = False, + lora_rank: int = 0, + ): + super().__init__() + + # attributes + self.grid_res_xy = grid_res_xy + self.grid_res_z = grid_res_z + self.grid_scale_xy = grid_scale_xy + self.grid_scale_z = grid_scale_z + self.deformation_multiplier = 4.0 + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + lora_rank=lora_rank, + ) + + if lora_rank > 0: + lora.mark_only_lora_as_trainable(self.transformer) + self.transformer.pos_embed.requires_grad = True + self.transformer.deconv.requires_grad = True + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + self.is_ortho = is_ortho + + if lora_rank > 0: + self.freeze_modules(encoder=True, transformer=False, + synthesizer=False) + + def freeze_modules(self, encoder=False, transformer=False, + synthesizer=False): + """ + Freeze specified modules + """ + if encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + if transformer: + for param in self.transformer.parameters(): + param.requires_grad = False + if synthesizer: + for param in self.synthesizer.parameters(): + param.requires_grad = False + + def init_flexicubes_geometry(self, device, fovy=50.0, is_ortho=False): + if not is_ortho: + camera = PerspectiveCamera(fovy=fovy, device=device) + else: + camera = OrthogonalCamera(device=device) + + with torch.cuda.amp.autocast(enabled=False): + renderer = NeuralRender(device, camera_model=camera) + self.geometry = FlexiCubesGeometry( + grid_res_xy=self.grid_res_xy, + grid_res_z=self.grid_res_z, + scale_xy=self.grid_scale_xy, + scale_z=self.grid_scale_z, + renderer=renderer, + render_type='neural_render', + device=device, + ) + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # decode triplanes + planes = self.transformer(image_feats) + + return planes + + def get_sdf_deformation_prediction(self, planes, levels=None): + ''' + Predict SDF and deformation for tetrahedron vertices + :param planes: triplane feature map for the geometry + ''' + init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) + + # Step 1: predict the SDF and deformation + sdf, deformation, weight, semantics = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_geometry_prediction, + planes, + init_position, + self.geometry.indices, + use_reentrant=False, + ) + + new_sdf = torch.zeros_like(sdf) + + for i_batch in range(sdf.shape[0]): + if levels[i_batch] == 0: # preserve all + new_sdf[i_batch] = sdf[i_batch] + elif levels[i_batch] == 1: # 
discard hair + new_sdf[i_batch] = torch.maximum( + sdf[i_batch], + semantics[i_batch][:, 0:1] - torch.max(semantics[i_batch][:, 1:], dim=-1, keepdim=True).values + ) + elif levels[i_batch] == 2: # discard hair and cloth + new_sdf[i_batch] = torch.maximum( + sdf[i_batch], + torch.maximum(semantics[i_batch][:, 0:1], semantics[i_batch][:, 3:4]) - \ + torch.maximum(semantics[i_batch][:, 1:2], semantics[i_batch][:, 2:3]) + ) + elif levels[i_batch] == 3: # only cloth + cloth_mask = torch.max(semantics[i_batch], dim=-1, keepdim=True).indices == 3 + # max pooling to get the cloth mask 3x3x3 + cloth_mask_nxnxn = cloth_mask.reshape((self.grid_res_xy + 1, self.grid_res_xy + 1, self.grid_res_z + 1)) + + kernel = torch.zeros(3, 3, 3, device=cloth_mask.device, dtype=cloth_mask.dtype) + kernel[1, 1, 0] = 1 + kernel[1, 1, 2] = 1 + kernel[1, 0, 1] = 1 + kernel[1, 2, 1] = 1 + kernel[0, 1, 1] = 1 + kernel[2, 1, 1] = 1 + kernel[1, 1, 1] = 1 + kernel = kernel.unsqueeze(0).unsqueeze(0).float() + cloth_mask_nxnxn = torch.nn.functional.conv3d( + cloth_mask_nxnxn.unsqueeze(0).unsqueeze(0).float(), + kernel, padding=1 + ).reshape(-1) + cloth_mask = cloth_mask_nxnxn > 0.5 + + new_sdf[i_batch] = torch.maximum( + sdf[i_batch], + torch.max(semantics[i_batch][:, 0:3], dim=-1, keepdim=True).values - semantics[i_batch][:, 3:4] + ) + new_sdf[i_batch][cloth_mask > 0.5] = sdf[i_batch][cloth_mask > 0.5] + + elif levels[i_batch] == 4: # only hair + hair_mask = torch.max(semantics[i_batch], dim=-1, keepdim=True).indices == 0 + # max pooling to get the hair mask 3x3x3 + hair_mask_nxnxn = hair_mask.reshape((self.grid_res_xy + 1, self.grid_res_xy + 1, self.grid_res_z + 1)) + + kernel = torch.zeros(3, 3, 3, device=hair_mask.device, dtype=hair_mask.dtype) + kernel[1, 1, 0] = 1 + kernel[1, 1, 2] = 1 + kernel[1, 0, 1] = 1 + kernel[1, 2, 1] = 1 + kernel[0, 1, 1] = 1 + kernel[2, 1, 1] = 1 + kernel[1, 1, 1] = 1 + kernel = kernel.unsqueeze(0).unsqueeze(0).float() + hair_mask_nxnxn = torch.nn.functional.conv3d( + hair_mask_nxnxn.unsqueeze(0).unsqueeze(0).float(), + kernel, padding=1 + ).reshape(-1) + hair_mask = hair_mask_nxnxn > 0.5 + + new_sdf[i_batch] = torch.maximum( + sdf[i_batch], + torch.max(semantics[i_batch][:, 1:4], dim=-1, keepdim=True).values - semantics[i_batch][:, 0:1] + ) + new_sdf[i_batch][hair_mask > 0.5] = sdf[i_batch][hair_mask > 0.5] + + sdf = new_sdf + + # Step 2: Normalize the deformation to avoid the flipped triangles. 
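+ # Note (descriptive comment): tanh bounds each per-vertex offset to (-1, 1); dividing by
+ # grid_res_z * deformation_multiplier keeps the offset to a small fraction of one grid cell,
+ # so deformed vertices stay close to their rest positions and neighbouring cells cannot overlap.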
+ deformation = 1.0 / (self.grid_res_z * self.deformation_multiplier) * torch.tanh(deformation) + sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=sdf.dtype) + + #### + # Step 3: Fix some sdf if we observe empty shape (full positive or full negative) + sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res_xy + 1, self.grid_res_xy + 1, self.grid_res_z + 1)) + sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) + pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) + neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) + zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) + if torch.sum(zero_surface).item() > 0: + update_sdf = torch.zeros_like(sdf[0:1]) + max_sdf = sdf.max() + min_sdf = sdf.min() + update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero + update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero + new_sdf = torch.zeros_like(sdf) + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + new_sdf[i_batch:i_batch + 1] += update_sdf + update_mask = (new_sdf == 0).to(sdf.dtype) + # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative) + sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) + sdf_reg_loss = sdf_reg_loss * zero_surface.to(sdf_reg_loss.dtype) + sdf = sdf * update_mask + new_sdf * (1 - update_mask) + + # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative) + final_sdf = [] + final_def = [] + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + final_sdf.append(sdf[i_batch: i_batch + 1].detach()) + final_def.append(deformation[i_batch: i_batch + 1].detach()) + else: + final_sdf.append(sdf[i_batch: i_batch + 1]) + final_def.append(deformation[i_batch: i_batch + 1]) + sdf = torch.cat(final_sdf, dim=0) + deformation = torch.cat(final_def, dim=0) + return sdf, deformation, sdf_reg_loss, weight + + def get_geometry_prediction(self, planes=None, levels=None): + ''' + Function to generate mesh with give triplanes + :param planes: triplane features + ''' + # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid. 
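+ # Note (descriptive comment): this call returns per-vertex SDF values, the bounded deformation
+ # offsets, the empty-shape regularization loss, and the FlexiCubes per-cell weights that are
+ # passed to geometry.get_mesh() below.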
+ sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes, levels=levels) + v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation + tets = self.geometry.indices + n_batch = planes.shape[0] + v_list = [] + f_list = [] + flexicubes_surface_reg_list = [] + + # Step 2: Using marching tet to obtain the mesh + for i_batch in range(n_batch): + verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( + v_deformed[i_batch], + sdf[i_batch].squeeze(dim=-1), + with_uv=False, + indices=tets, + weight_n=weight[i_batch].squeeze(dim=-1), + is_training=self.training, + ) + flexicubes_surface_reg_list.append(flexicubes_surface_reg) + v_list.append(verts) + f_list.append(faces) + + flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() + flexicubes_weight_reg = (weight ** 2).mean() + + return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None, levels=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.to(tex_pos.dtype) + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=tex_pos.dtype)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat, semantic_feat = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_texture_prediction, + planes, + tex_pos, + use_reentrant=False, + ) + + for idx_batch in range(semantic_feat.shape[0]): + if levels[idx_batch] == 0: + pass + elif levels[idx_batch] == 1: + semantic_feat[idx_batch, ..., 0:1] = 0 + semantic_feat[idx_batch, ..., 1:] = (semantic_feat[idx_batch, ..., 1:] + 1e-6) / \ + (semantic_feat[idx_batch, ..., 1:].sum(dim=-1, keepdim=True) + 1e-6) + elif levels[idx_batch] == 2: + semantic_feat[idx_batch, ..., 0:1] = 0 + semantic_feat[idx_batch, ..., 3:4] = 0 + semantic_feat[idx_batch, ..., 1:3] = (semantic_feat[idx_batch, ..., 1:3] + 1e-6) / \ + (semantic_feat[idx_batch, ..., 1:3].sum(dim=-1, keepdim=True) + 1e-6) + elif levels[idx_batch] == 3: + pass + elif levels[idx_batch] == 4: + pass + else: + raise ValueError(f"Invalid level {levels[idx_batch]}") + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device, dtype=tex_feat.dtype) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = 
tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + final_semantic_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], semantic_feat.shape[-1], device=semantic_feat.device, dtype=semantic_feat.dtype) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_semantic_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_semantic_feat[i][expanded_hard_mask[i]] = semantic_feat[i][:n_point_list[i]].reshape(-1) + semantic_feat = final_semantic_feat + + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]), \ + semantic_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], semantic_feat.shape[-1]) + + def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256, dtype=torch.float32): + ''' + Function to render a generated mesh with nvdiffrast + :param mesh_v: List of vertices for the mesh + :param mesh_f: List of faces for the mesh + :param cam_mv: 4x4 rotation matrix + :return: + ''' + return_value_list = [] + for i_mesh in range(len(mesh_v)): + return_value = self.geometry.render_mesh( + mesh_v[i_mesh], + mesh_f[i_mesh].int(), + cam_mv[i_mesh], + resolution=render_size, + hierarchical_mask=False, + dtype=dtype + ) + return_value_list.append(return_value) + + return_keys = return_value_list[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in return_value_list] + return_value[k] = value + + mask = torch.cat(return_value['mask'], dim=0) + hard_mask = torch.cat(return_value['hard_mask'], dim=0) + tex_pos = return_value['tex_pos'] + depth = torch.cat(return_value['depth'], dim=0) + normal = torch.cat(return_value['normal'], dim=0) + return mask, hard_mask, tex_pos, depth, normal + + def forward_geometry(self, planes, render_cameras, render_size=256, levels=None): + ''' + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. 
+ :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape + ''' + B, NV = render_cameras.shape[:2] + + # Generate 3D mesh first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes, levels=levels) + + # Render the mesh into 2D image (get 3d position of each image plane) + cam_mv = render_cameras + run_n_view = cam_mv.shape[1] + antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size, dtype=planes.dtype) + + tex_hard_mask = hard_mask + tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] + tex_hard_mask = torch.cat( + [torch.cat( + [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] + for i_view in range(run_n_view)], dim=2) + for i in range(planes.shape[0])], dim=0) + + # Querying the texture field to predict the texture feature for each pixel on the image + tex_feat, semantic_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask, levels=levels) + background_feature = torch.ones_like(tex_feat) # white background + + # Merge them together + img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) + semantic_feat = semantic_feat * tex_hard_mask + + # We should split it back to the original image shape + img_feat = torch.cat( + [torch.cat( + [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + semantic_feat = torch.cat( + [torch.cat( + [semantic_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + semantic = semantic_feat.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + out = { + 'img': img, + 'semantic': semantic, + 'mask': antilias_mask, + 'depth': depth, + 'normal': normal, + 'sdf': sdf, + 'mesh_v': mesh_v, + 'mesh_f': mesh_f, + 'sdf_reg_loss': sdf_reg_loss, + } + return out + + def forward_geometry_separate(self, planes, render_cameras, render_size=256, levels=None): + ''' + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. 
+ :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape + ''' + B, NV = render_cameras.shape[:2] + + mesh_vs, mesh_fs, sdfs, deformations, v_deformeds = [], [], [], [], [] + + # Generate 3D mesh first + for _ in [0, 3, 4, 2]: + mesh_v, mesh_f, sdf, deformation, v_deformed, _ = self.get_geometry_prediction(planes, levels=torch.tensor([_]).to(planes.device)) + mesh_vs.append(mesh_v) + mesh_fs.append(mesh_f) + sdfs.append(sdf) + deformations.append(deformation) + v_deformeds.append(v_deformed) + + imgs, semantics, masks, depths, normals = [], [], [], [], [] + + # Render the mesh into 2D image (get 3d position of each image plane) + cam_mv = render_cameras + run_n_view = cam_mv.shape[1] + + for mesh_v, mesh_f, sdf, deformation, v_deformed in zip(mesh_vs, mesh_fs, sdfs, deformations, v_deformeds): + antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size) + tex_hard_mask = hard_mask + tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] + tex_hard_mask = torch.cat( + [torch.cat( + [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] + for i_view in range(run_n_view)], dim=2) + for i in range(planes.shape[0])], dim=0) + + # Querying the texture field to predict the texture feature for each pixel on the image + tex_feat, semantic_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask, levels=levels) + background_feature = torch.ones_like(tex_feat) # white background + + # Merge them together + img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) + semantic_feat = semantic_feat * tex_hard_mask + + # We should split it back to the original image shape + img_feat = torch.cat( + [torch.cat( + [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + semantic_feat = torch.cat( + [torch.cat( + [semantic_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + semantic = semantic_feat.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + imgs.append(img) + semantics.append(semantic) + masks.append(antilias_mask) + depths.append(depth) + normals.append(normal) + + + out = { + 'imgs': imgs, + 'semantics': semantics, + 'masks': masks, + 'depths': depths, + 'normals': normals, + 'sdfs': sdfs, + 'mesh_vs': mesh_vs, + 'mesh_fs': mesh_fs, + } + return out + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + out = self.forward_geometry(planes, render_cameras, render_size=render_size) + + return { + 'planes': planes, + **out + } + + def extract_mesh( + self, + planes: torch.Tensor, + use_texture_map: bool = False, + texture_resolution: int = 1024, + levels=None, + **kwargs, + ): + ''' + Extract a 3D mesh from FlexiCubes. Only support batch_size 1. 
+ :param planes: triplane features + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texure map + ''' + assert planes.shape[0] == 1 + device = planes.device + + # predict geometry first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes, levels=levels) + vertices, faces = mesh_v[0], mesh_f[0] + + if not use_texture_map: + # query vertex colors + vertices_tensor = vertices.unsqueeze(0) + vertices_colors = self.synthesizer.get_texture_prediction( + planes, vertices_tensor)[0].clamp(0, 1).squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors + + # use x-atlas to get uv mapping for the mesh + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.to(planes.dtype) + + # query the texture field to get the RGB color for texture map + tex_feat, _ = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask, levels=levels) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/slrm/models/renderer/__init__.py b/slrm/models/renderer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/slrm/models/renderer/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/slrm/models/renderer/synthesizer.py b/slrm/models/renderer/synthesizer.py new file mode 100755 index 0000000000000000000000000000000000000000..5255ce48c565ca59026abe2032061652c8c49be4 --- /dev/null +++ b/slrm/models/renderer/synthesizer.py @@ -0,0 +1,237 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. +# +# Modified by Yuze He +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import ImportanceRenderer +from .utils.ray_sampler import OrthoRaySampler, RaySampler + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. 
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + self.semantic_net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 4), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def forward(self, sampled_features, ray_directions): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + x = x.contiguous().view(N*M, C) + + hidden = self.net[:-1](x) + semantic = self.semantic_net(x) + x = self.net[-1](hidden) + + x = x.view(N, M, -1) + semantic = semantic.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + semantic = torch.softmax(semantic, dim=-1) + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma, 'semantic': semantic} + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int, is_ortho: bool = False): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + + # ray sampler + if is_ortho: + self.ray_sampler = OrthoRaySampler() + self.is_ortho = True + else: + self.ray_sampler = RaySampler() + self.is_ortho = False + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + def forward(self, planes, cameras, render_size=128, crop_params=None, levels=None): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # render_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + if self.is_ortho: + ortho_scale = cameras[..., 25].view(N, M) + else: + ortho_scale = None + + self.rendering_kwargs['levels'] = levels.repeat_interleave(M, dim=0) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + ortho_scale=ortho_scale, + render_size=render_size, + ) + assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" 
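+ # Note (descriptive comment): ray_origins / ray_directions are flattened per view,
+ # with shape (N*M, render_size * render_size, 3), before the optional crop below.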
+ assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Crop rays if crop_params is available + if crop_params is not None: + ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) + ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) + i, j, h, w = crop_params + ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + + # Perform volume rendering + rgb_samples, depth_samples, semantic_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, + ) + + # Reshape into 'raw' neural-rendered image + if crop_params is not None: + Himg, Wimg = crop_params[2:] + else: + Himg = Wimg = render_size + rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + semantic_samples = semantic_samples.permute(0, 2, 1).reshape(N, M, semantic_samples.shape[-1], Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + out = { + 'images_rgb': rgb_images, + 'images_depth': depth_images, + 'images_weight': weight_images, + 'images_semantic': semantic_samples, + } + return out + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/slrm/models/renderer/synthesizer_mesh.py b/slrm/models/renderer/synthesizer_mesh.py new file mode 100755 index 0000000000000000000000000000000000000000..48ee4a221eea3526413b468561b8c9da9a693a7f --- /dev/null +++ b/slrm/models/renderer/synthesizer_mesh.py @@ -0,0 +1,164 @@ +# ORIGINAL LICENSE 
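# A minimal sketch, assuming a `synthesizer` instance of the volume-rendering
# TriplaneSynthesizer defined just above (the one exposing forward_grid / forward_points):
# forward_grid queries the decoder on a dense cube of points so the activated density
# can be thresholded into an occupancy volume, e.g. for a marching-cubes step.
# The grid size, the iso-level and the helper name are assumptions for illustration.
import torch

def query_density_grid(synthesizer, planes, grid_size=128, iso_level=10.0):
    # planes: (N, 3, D', H', W') triplane features
    with torch.no_grad():
        fields = synthesizer.forward_grid(planes, grid_size=grid_size)
    sigma = fields['sigma'].squeeze(-1)   # (N, grid_size, grid_size, grid_size)
    occupancy = sigma > iso_level         # hypothetical threshold; tune per model
    return sigma, occupancy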
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. +# +# Modified by Yuze He +# The modifications are subject to the same license as the original. + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + + self.net_sdf = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1), + ) + self.net_rgb = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_deformation = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_weight = nn.Sequential( + nn.Linear(8 * 3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 21), + ) + self.semantic_net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 4), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_geometry_prediction(self, sampled_features, flexicubes_indices): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + sdf = self.net_sdf(sampled_features) + deformation = self.net_deformation(sampled_features) + + grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + def get_texture_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + rgb = self.net_rgb(sampled_features) + rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + return rgb + + def get_semantic_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + semantic = self.semantic_net(sampled_features) + semantic = torch.softmax(semantic, dim=-1) + + return semantic + + +class 
TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # modules + self.plane_axes = generate_planes() + self.decoder = OSGDecoder(n_features=triplane_dim) + + def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) + semantic = self.decoder.get_semantic_prediction(sampled_features) + return sdf, deformation, weight, semantic + + def get_texture_prediction(self, planes, sample_coordinates): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + rgb = self.decoder.get_texture_prediction(sampled_features) + semantic = self.decoder.get_semantic_prediction(sampled_features) + return rgb, semantic diff --git a/slrm/models/renderer/utils/__init__.py b/slrm/models/renderer/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/slrm/models/renderer/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/slrm/models/renderer/utils/math_utils.py b/slrm/models/renderer/utils/math_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af --- /dev/null +++ b/slrm/models/renderer/utils/math_utils.py @@ -0,0 +1,118 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
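# A minimal sketch, assuming a `mesh_synthesizer` instance of the mesh-path
# TriplaneSynthesizer in slrm/models/renderer/synthesizer_mesh.py above and a
# FlexiCubes-style grid: get_geometry_prediction returns per-vertex SDF, deformation
# and semantics plus per-cube weights, while get_texture_prediction returns colors.
# The tensor shapes in the comments are assumptions for illustration only.
def query_mesh_fields(mesh_synthesizer, planes, grid_vertices, flexicubes_indices):
    # planes: (N, 3, D', H', W'); grid_vertices: (N, V, 3) inside the [-1, 1] box;
    # flexicubes_indices: (C, 8) vertex indices of each cube
    sdf, deformation, weights, semantic = mesh_synthesizer.get_geometry_prediction(
        planes, grid_vertices, flexicubes_indices)
    rgb, _ = mesh_synthesizer.get_texture_prediction(planes, grid_vertices)
    return sdf, deformation, weights, semantic, rgb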
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
+ """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/slrm/models/renderer/utils/ray_marcher.py b/slrm/models/renderer/utils/ray_marcher.py new file mode 100755 index 0000000000000000000000000000000000000000..0dbe78bad1ea4fc7ae5c9fbabd10ef77715a7a86 --- /dev/null +++ b/slrm/models/renderer/utils/ray_marcher.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. +# +# Modified by Yuze He +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + + def run_forward(self, colors, densities, depths, semantics, rendering_options, normals=None): + dtype = colors.dtype + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + semantics_mid = (semantics[:, :, :-1] + semantics[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta).to(dtype) + + for bid in range(len(rendering_options['levels'])): + if rendering_options['levels'][bid] == 0: + pass + elif rendering_options['levels'][bid] == 1: + alpha[bid] = alpha[bid] * (1 - semantics_mid[bid, ..., 0:1]) + semantics_mid[bid, ..., 0:1] = 0 # remove hair + semantics_mid[bid, ..., 1:] = (semantics_mid[bid, ..., 1:] + 1e-6) / (torch.sum(semantics_mid[bid, ..., 1:], dim=-1, keepdim=True) + 1e-6) + elif rendering_options['levels'][bid] == 2: + alpha[bid] = alpha[bid] * (1 - semantics_mid[bid, ..., 0:1] - semantics_mid[bid, ..., 3:4]) + semantics_mid[bid, ..., 0:1] = 0 # remove hair + semantics_mid[bid, ..., 3:4] = 0 # remove cloth + semantics_mid[bid, ..., 1:3] = (semantics_mid[bid, ..., 1:3] + 1e-6) / (torch.sum(semantics_mid[bid, ..., 1:3], dim=-1, keepdim=True) + 1e-6) + else: + raise NotImplementedError("Only 0, 1, 2 levels are supported") + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + weights = weights.to(dtype) + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + composite_semantics = torch.sum(weights * semantics_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, composite_semantics, weights + + + def forward(self, colors, densities, depths, semantics, rendering_options, normals=None): + if normals is not None: + raise NotImplementedError("Normals are not supported in the ray marcher yet.") + composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) + return composite_rgb, composite_depth, composite_normals, weights + + composite_rgb, composite_depth, composite_semantic, weights = self.run_forward(colors, densities, depths, semantics, rendering_options) + return composite_rgb, composite_depth, composite_semantic, weights diff --git a/slrm/models/renderer/utils/ray_sampler.py b/slrm/models/renderer/utils/ray_sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..372cace75423124fdade7ef09376c3d7c6b449b7 --- 
/dev/null +++ b/slrm/models/renderer/utils/ray_sampler.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. +# +# Modified by Yuze He +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, ortho_scale, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + ortho_scale: NOT USED + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + dtype = cam2world_matrix.dtype + device = cam2world_matrix.device + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=dtype, device=device), + torch.arange(render_size, dtype=dtype, device=device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.ones((N, M), dtype=dtype, device=device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs + + +class OrthoRaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, ortho_scale, 
render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + intrinsics: NOT USED + ortho_scale: (N) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 3) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size).to(cam2world_matrix), + torch.arange(render_size).to(cam2world_matrix), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.zeros((N, M)).to(cam2world_matrix) + + x_lift = (x_cam - 0.5) * ortho_scale + y_lift = (y_cam - 0.5) * ortho_scale + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ]).to(cam2world_matrix).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs_cam = torch.stack([ + torch.zeros((N, M), device=cam2world_matrix.device, dtype=cam2world_matrix.dtype), + torch.zeros((N, M), device=cam2world_matrix.device, dtype=cam2world_matrix.dtype), + torch.ones((N, M), device=cam2world_matrix.device, dtype=cam2world_matrix.dtype), + ], dim=-1) + ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) + + return ray_origins, ray_dirs diff --git a/slrm/models/renderer/utils/renderer.py b/slrm/models/renderer/utils/renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..a20aa6300eb062a79efaf269b89b6b193724faa0 --- /dev/null +++ b/slrm/models/renderer/utils/renderer.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. +# +# Modified by Yuze He +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils + + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. 
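# A small sanity-check sketch for the two samplers in
# slrm/models/renderer/utils/ray_sampler.py above: the perspective RaySampler varies
# ray directions per pixel from a shared camera origin, while OrthoRaySampler emits
# parallel rays whose origins vary per pixel. The identity cam2world, the normalized
# intrinsics values and the importability of the package path are assumptions.
import torch
from slrm.models.renderer.utils.ray_sampler import RaySampler, OrthoRaySampler

def check_ray_samplers(render_size=8):
    c2w = torch.eye(4).unsqueeze(0)                        # (1, 4, 4) camera at the origin
    K = torch.tensor([[[1.0, 0.0, 0.5],
                       [0.0, 1.0, 0.5],
                       [0.0, 0.0, 1.0]]])                  # normalized pinhole intrinsics
    persp_o, persp_d = RaySampler()(c2w, K, None, render_size)
    ortho_o, ortho_d = OrthoRaySampler()(c2w, K, torch.tensor([1.0]), render_size)
    assert torch.allclose(persp_o[0, 0], persp_o[0, -1])   # pinhole: one shared origin
    assert torch.allclose(ortho_d[0, 0], ortho_d[0, -1])   # ortho: parallel directions
    return persp_d, ortho_o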
+ + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + dtype = plane_features.dtype + + coordinates = (2/box_warp) * coordinates # add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample( + plane_features, + projected_coordinates.to(dtype), + mode=mode, + padding_mode=padding_mode, + align_corners=False, + ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample( + grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', + padding_mode='zeros', + align_corners=False, + ) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. 
+ """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _out['sigma'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + semantics_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 4, device=device, dtype=DATA_TYPE) + densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] + semantics_pass[mask_inbox] = _out['semantic'][mask_inbox] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + semantics_pass = semantics_pass.reshape(batch_size, num_rays, samples_per_ray, semantics_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass, semantics_pass + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + # Coarse Pass + colors_coarse, densities_coarse, semantics_coarse = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, semantics_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine, semantics_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities, all_semantics 
= self.unify_samples(depths_coarse, colors_coarse, densities_coarse, semantics_coarse, + depths_fine, colors_fine, densities_fine, semantics_fine) + + rgb_final, depth_final, semantic_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, all_semantics, rendering_options) + else: + rgb_final, depth_final, semantic_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, semantics_coarse, rendering_options) + + return rgb_final, depth_final, semantic_final, weights.sum(2) + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, semantics1, depths2, colors2, densities2, semantics2, normals1=None, normals2=None): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + all_semantics = torch.cat([semantics1, semantics2], dim = -2) + + if normals1 is not None and normals2 is not None: + all_normals = torch.cat([normals1, normals2], dim = -2) + else: + all_normals = None + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + all_semantics = torch.gather(all_semantics, -2, indices.expand(-1, -1, -1, all_semantics.shape[-1])) + + if all_normals is not None: + all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) + return all_depths, all_colors, all_normals, all_densities, all_semantics + + return all_depths, all_colors, all_densities, all_semantics + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. 
- depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) 
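        # The block below is standard inverse-transform (inverse-CDF) sampling:
        # normalize the smoothed coarse weights into a pdf over the bins, integrate
        # it into a cdf, draw uniform samples u, locate them in the cdf with
        # torch.searchsorted, and linearly interpolate the bin edges so that the
        # extra fine samples concentrate where the coarse pass found high density.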
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = torch.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + return intrinsics + + +def get_era3d_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. 
+ """ + azimuths = np.array([0, 45, 90, 180, 270, 315]).astype(float) + elevations = np.array([0, 0, 0, 0, 0, 0]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) + + +if __name__ == '__main__': + get_era3d_input_cameras() diff --git a/slrm/utils/infer_util.py b/slrm/utils/infer_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f2faf2bf3b12d4af7b33cb2292da2b5ed62eb52e --- /dev/null +++ b/slrm/utils/infer_util.py @@ -0,0 +1,97 @@ +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any + + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def images_to_video( + images: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + video_dir = os.path.dirname(output_path) + video_name = os.path.basename(output_path) + os.makedirs(video_dir, exist_ok=True) + + frames = [] + for i in range(len(images)): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) + + +def save_video( + frames: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] + writer = imageio.get_writer(output_path, fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() \ No newline at end of file diff --git 
a/slrm/utils/mesh_util.py b/slrm/utils/mesh_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec4663eeaa5c54209e08771969ec4f2a739c0b4 --- /dev/null +++ b/slrm/utils/mesh_util.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import trimesh +import cv2 +import numpy as np +import nvdiffrast.torch as dr +from PIL import Image + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'obj') + + +def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'glb') + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == 'v': + v.append([float(d) for d in data[1:]]) + if data[0] == 'f': + data = [da.split('/')[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, 'r') + for line in 
meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == 'v': + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == 'vt': + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == 'f': + data = [da.split('/') for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/slrm/utils/train_util.py b/slrm/utils/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2e65421bffa8cc42c1517e86f2dfd8183caf52ab --- /dev/null +++ b/slrm/utils/train_util.py @@ -0,0 +1,26 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls)
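# A minimal usage sketch for instantiate_from_config / get_obj_from_str above.
# The config layout ("target" plus optional "params") matches what the function
# expects; the concrete class path below is only an example, not taken from the
# repo's configs.
if __name__ == "__main__":
    example_config = {
        "target": "torch.nn.Linear",                      # dotted path resolved by get_obj_from_str
        "params": {"in_features": 8, "out_features": 4},
    }
    layer = instantiate_from_config(example_config)
    count_params(layer, verbose=True)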