diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index a8cc3865dcc72ed09a6108308354af98f9eeffeb..04ac80b975ace53c2738b1053f5d4d9cf54d303f 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,11 @@
---
-title: LVM
-emoji: 🔥
-colorFrom: yellow
-colorTo: gray
+title: VQLM Demo
+emoji: 🎨
+colorFrom: "yellow"
+colorTo: "blue"
sdk: gradio
-sdk_version: 4.36.1
+sdk_version: "4.29.0"
app_file: app.py
pinned: false
-license: apache-2.0
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d34083aa00b5137708967eba4245149d4545eb6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,244 @@
+import gradio as gr
+import numpy as np
+import mlxu
+import os
+import re
+import torch
+
+from io import BytesIO
+from natsort import natsorted
+from PIL import Image
+
+from inference import LocalInferenceModel
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ host='0.0.0.0',
+ port=5000,
+ dtype='float16',
+ checkpoint='Emma02/LVM_ckpts',
+ torch_devices='',
+ context_frames=16,
+)
+
+def natural_sort_key(s):
+ return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
+
+def load_example_image_groups(directory):
+ example_groups = {}
+ for subdir in os.listdir(directory):
+ subdir_path = os.path.join(directory, subdir)
+ if os.path.isdir(subdir_path):
+ example_groups[subdir] = []
+ images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+ images = natsorted(images, key=natural_sort_key)
+ for filename in images:
+ img = Image.open(os.path.join(subdir_path, filename))
+ example_groups[subdir].append(img)
+ return example_groups
+
+def main(_):
+ assert FLAGS.checkpoint != ''
+
+ model = LocalInferenceModel(
+ checkpoint=FLAGS.checkpoint,
+ torch_device=torch.device("cuda"),
+ dtype=FLAGS.dtype,
+ context_frames=FLAGS.context_frames,
+ use_lock=False,
+ )
+
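+    # Build a 256x24 black-and-white checkerboard strip that is inserted between
+    # generated frames as a visual separator in the output gallery.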
+ checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
+ checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
+ checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
+
+ def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
+ assert len(input_images) > 0
+ input_images = [
+ np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
+ for img in input_images
+ ]
+ input_images = np.stack(input_images, axis=0)
+ output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
+
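+        # For each candidate, lay the generated frames out horizontally with the
+        # checkerboard strip between consecutive frames, then convert to a PIL image.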
+ generated_images = []
+ for candidate in output_images:
+ concatenated_image = []
+ for i, img in enumerate(candidate):
+ concatenated_image.append(img)
+ if i < len(candidate) - 1:
+ concatenated_image.append(checkerboard)
+ generated_images.append(
+ Image.fromarray(
+ (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
+ )
+ )
+
+ return generated_images
+
+ with gr.Blocks(css="""
+ .small-button {
+ padding: 5px 10px;
+ min-width: 80px;
+ }
+ .large-gallery img {
+ width: 100%;
+ height: auto;
+ max-height: 150px;
+ }
+ """) as demo:
+ with gr.Column():
+ image_list = gr.State([])
+ gr.Markdown('# VQLM Demo')
+ gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
+ gr.Markdown('## Inputs')
+ with gr.Row():
+ upload_drag = gr.File(
+ type='binary',
+ file_types=['image'],
+ file_count='multiple',
+ )
+ with gr.Column():
+ gen_length_slider = gr.Slider(
+ label='Generation length',
+ minimum=1,
+ maximum=32,
+ value=1,
+ step=1,
+ interactive=True,
+ )
+ n_candidates_slider = gr.Slider(
+ label='Number of candidates',
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ interactive=True,
+ )
+ temp_slider = gr.Slider(
+ label='Temperature',
+ minimum=0,
+ maximum=2.0,
+ value=1.0,
+ interactive=True,
+ )
+ top_p_slider = gr.Slider(
+ label='Top p',
+ minimum=0,
+ maximum=1.0,
+ value=0.9,
+ interactive=True,
+ )
+ clear_btn = gr.Button(
+ value='Clear',
+ elem_classes=['small-button'],
+ )
+ generate_btn = gr.Button(
+ value='Generate',
+ interactive=False,
+ elem_classes=['small-button'],
+ )
+ input_gallery = gr.Gallery(
+ columns=7,
+ rows=1,
+ object_fit='scale-down',
+ label="Input image sequence"
+ )
+ gr.Markdown('## Outputs')
+ output_gallery = gr.Gallery(
+ columns=4,
+ object_fit='scale-down',
+ label="Output image"
+ )
+
+ def upload_image_fn(files, images):
+ for file in files:
+ images.append(Image.open(BytesIO(file)))
+
+ return {
+ upload_drag: None,
+ image_list: images,
+ input_gallery: images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ def clear_fn():
+ return {
+ image_list: [],
+ input_gallery: [],
+ generate_btn: gr.update(interactive=False),
+ output_gallery: [],
+ }
+
+ def disable_generate_btn():
+ return {
+ generate_btn: gr.update(interactive=False),
+ }
+
+ def generate_fn(images, n_candidates, gen_length, temperature, top_p):
+ new_images = generate_images(
+ images,
+ gen_length,
+ n_candidates=n_candidates,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ return {
+ output_gallery: new_images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ upload_drag.upload(
+ upload_image_fn,
+ inputs=[upload_drag, image_list],
+ outputs=[upload_drag, image_list, input_gallery, generate_btn],
+ )
+ clear_btn.click(
+ clear_fn,
+ inputs=None,
+ outputs=[image_list, input_gallery, generate_btn, output_gallery],
+ )
+ generate_btn.click(
+ disable_generate_btn,
+ inputs=None,
+ outputs=[generate_btn],
+ ).then(
+ generate_fn,
+ inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
+ outputs=[output_gallery, generate_btn],
+ )
+
+ example_groups = load_example_image_groups('prompts')
+
+ def add_image_group_fn(group_name, images):
+ new_images = images + example_groups[group_name]
+ return {
+ image_list: new_images,
+ input_gallery: new_images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ for group_name, group_images in example_groups.items():
+ with gr.Row():
+ with gr.Column(scale=3):
+ add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
+ with gr.Column(scale=7):
+ group_gallery = gr.Gallery(
+ value=[Image.fromarray(np.array(img)) for img in group_images],
+ columns=5,
+ rows=1,
+ object_fit='scale-down',
+ label=group_name,
+ elem_classes=['large-gallery'],
+ )
+
+ add_button.click(
+ add_image_group_fn,
+ inputs=[gr.State(group_name), image_list],
+ outputs=[image_list, input_gallery, generate_btn],
+ )
+
+ demo.launch()
+
+if __name__ == "__main__":
+ mlxu.run(main)
+
diff --git a/batch_generation.py b/batch_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c044af5d192ad050be4199463085cb8df214dc8
--- /dev/null
+++ b/batch_generation.py
@@ -0,0 +1,223 @@
+"""
+Batch generation for a sequence of images. This script accepts a jsonl file
+as input, where each line is a JSON dictionary representing one example in the
+evaluation set. The dictionary should have two keys:
+
+ input: a list of paths to the input images as context to the model.
+    output: a string giving the path where the generated output should be saved.
+
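+For example, a line might look like this (paths are illustrative):
+
+    {"input": ["frames/000.png", "frames/001.png"], "output": "generated/000.png"}
+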
+This script runs the model to generate the output images, concatenates the
+input and output images together, and saves the result to the output path.
+"""
+
+import os
+import json
+from PIL import Image
+import numpy as np
+import mlxu
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+import einops
+import torch
+
+from .inference import MultiProcessInferenceModel
+from .utils import read_image_to_tensor, MultiProcessImageSaver
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ input_file='',
+ checkpoint='',
+ input_base_dir='',
+ output_base_dir='',
+ evaluate_mse=False,
+ json_input_key='input',
+ json_output_key='output',
+ json_target_key='target',
+ n_new_frames=1,
+ n_candidates=2,
+ context_frames=16,
+ temperature=1.0,
+ top_p=1.0,
+ n_workers=8,
+ dtype='float16',
+ torch_devices='',
+ batch_size_factor=4,
+ max_examples=0,
+ resize_output='',
+ include_input=False,
+)
+
+# Dataset over the records parsed from the input jsonl file.
+class MultiFrameDataset(torch.utils.data.Dataset):
+ def __init__(self, input_files, output_files, target_files=None):
+ assert len(input_files)
+ self.input_files = input_files
+ self.output_files = output_files
+ self.target_files = target_files
+
+ def __len__(self):
+ return len(self.input_files)
+
+ def __getitem__(self, idx):
+ original_size = Image.open(self.input_files[idx][-1]).size
+ input_images = np.stack(
+ [read_image_to_tensor(f) for f in self.input_files[idx]],
+ axis=0
+ )
+
+ if self.target_files is not None:
+ target_images = np.stack(
+ [read_image_to_tensor(f) for f in self.target_files[idx]],
+ axis=0
+ )
+ else:
+ target_images = None
+ return input_images, target_images, self.output_files[idx], np.array(original_size)
+
+
+def main(_):
+ assert FLAGS.checkpoint != ''
+
+ print(f'Loading checkpoint from {FLAGS.checkpoint}')
+ print(f'Evaluating input file from {FLAGS.input_file}')
+
+ # build a model.
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint,
+ torch_devices=FLAGS.torch_devices,
+ dtype=FLAGS.dtype,
+ context_frames=FLAGS.context_frames,
+ use_lock=True,
+ )
+
+    # Parse the input/output (and optional target) paths from the jsonl file.
+ input_files = []
+ output_files = []
+
+ if FLAGS.evaluate_mse:
+ target_files = []
+ else:
+ target_files = None
+
+ with mlxu.open_file(FLAGS.input_file, 'r') as f:
+ for line in f:
+ record = json.loads(line)
+ input_files.append(record[FLAGS.json_input_key])
+ output_files.append(record[FLAGS.json_output_key])
+ if FLAGS.evaluate_mse:
+ target_files.append(record[FLAGS.json_target_key])
+
+
+ if FLAGS.max_examples > 0:
+ input_files = input_files[:FLAGS.max_examples]
+ output_files = output_files[:FLAGS.max_examples]
+ if FLAGS.evaluate_mse:
+ target_files = target_files[:FLAGS.max_examples]
+
+ if FLAGS.input_base_dir != '':
+ input_files = [
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
+ for y in input_files
+ ]
+ if FLAGS.evaluate_mse:
+ target_files = [
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
+ for y in target_files
+ ]
+
+ if FLAGS.output_base_dir != '':
+ os.makedirs(FLAGS.output_base_dir, exist_ok=True)
+ output_files = [
+ os.path.join(FLAGS.output_base_dir, x)
+ for x in output_files
+ ]
+
+ dataset = MultiFrameDataset(input_files, output_files, target_files)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size_factor * model.n_processes,
+ shuffle=False,
+ num_workers=FLAGS.n_workers,
+ )
+
+ image_saver = MultiProcessImageSaver(FLAGS.n_workers)
+
+ mses = []
+
+ for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
+        # batch_images holds the context frames: [batch, sequence, height, width, channel]
+        batch_images = batch_images.numpy()
+        context_length = batch_images.shape[1]
+
+ generated_images = model(
+ batch_images,
+ FLAGS.n_new_frames,
+ FLAGS.n_candidates,
+ temperature=FLAGS.temperature,
+ top_p=FLAGS.top_p
+ )
+
+
+ repeated_batch = einops.repeat(
+ batch_images,
+ 'b s h w c -> b n s h w c',
+ n=FLAGS.n_candidates,
+ )
+ generated_images = np.array(generated_images)
+
+ if FLAGS.evaluate_mse:
+ batch_targets = einops.repeat(
+ batch_targets.numpy(),
+            'b s h w c -> b n s h w c',  # batch, candidate, sequence, height, width, channel
+ n=FLAGS.n_candidates,
+ )
+ channels = batch_targets.shape[-1]
+ # calculate mse loss.
+ mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
+
+ mses.append(mse * channels)
+
+
+ if FLAGS.include_input:
+ combined = einops.rearrange(
+ np.concatenate([repeated_batch, generated_images], axis=2),
+ 'b n s h w c -> b (n h) (s w) c'
+ )
+ else:
+ combined = einops.rearrange(
+ generated_images,
+ 'b n s h w c -> b (n h) (s w) c'
+ )
+ combined = (combined * 255).astype(np.uint8)
+
+ n_frames = FLAGS.n_new_frames
+ if FLAGS.include_input:
+ n_frames += context_length
+
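+        # The combined image is a grid that is n_candidates frames tall and n_frames
+        # frames wide, so the resize target scales (width, height) accordingly.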
+ if FLAGS.resize_output == '':
+ resizes = None
+
+ elif FLAGS.resize_output == 'original':
+ resizes = batch_sizes.numpy()
+ resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
+ else:
+ resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
+ resizes = np.array([resize] * len(batch_sizes))
+ resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
+
+ image_saver(combined, batch_output_files, resizes)
+
+ if FLAGS.evaluate_mse:
+ mses = np.concatenate(mses, axis=0)
+ print(f'MSE: {np.mean(mses)}')
+
+ image_saver.close()
+
+if __name__ == "__main__":
+ mlxu.run(main)
\ No newline at end of file
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..57060a49b8f8b92ef4243154e841268c5a3fbc28
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,263 @@
+from io import BytesIO
+import gradio as gr
+import uvicorn
+from fastapi import FastAPI
+from PIL import Image
+import numpy as np
+import mlxu
+import os
+import re
+from natsort import natsorted
+
+from .inference import MultiProcessInferenceModel
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ host='0.0.0.0',
+ port=5007,
+ dtype='float16',
+ checkpoint='',
+ torch_devices='',
+ context_frames=16,
+)
+
+def natural_sort_key(s):
+ return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
+
+def load_example_image_groups(directory):
+ example_groups = {}
+ for subdir in os.listdir(directory):
+ subdir_path = os.path.join(directory, subdir)
+ if os.path.isdir(subdir_path):
+ example_groups[subdir] = []
+ images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+ images = natsorted(images, key=natural_sort_key) # Natural sorting
+ for filename in images:
+ img = Image.open(os.path.join(subdir_path, filename))
+ example_groups[subdir].append(img)
+ return example_groups
+
+def main(_):
+ assert FLAGS.checkpoint != ''
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint,
+ torch_devices=FLAGS.torch_devices,
+ dtype=FLAGS.dtype,
+ context_frames=FLAGS.context_frames,
+ use_lock=True,
+ )
+
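+    # Build a 256x24 black-and-white checkerboard strip that is inserted between
+    # generated frames as a visual separator in the output gallery.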
+ checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
+ checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
+ checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
+
+ def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
+ assert len(input_images) > 0
+ input_images = [
+ np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
+ for img in input_images
+ ]
+ input_images = np.stack(input_images, axis=0)
+ output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
+
+ generated_images = []
+ for candidate in output_images:
+ concatenated_image = []
+ for i, img in enumerate(candidate):
+ concatenated_image.append(img)
+ if i < len(candidate) - 1:
+ concatenated_image.append(checkerboard)
+ generated_images.append(
+ Image.fromarray(
+ (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
+ )
+ )
+
+ return generated_images
+
+ with gr.Blocks(css="""
+ .small-button {
+ padding: 5px 10px;
+ min-width: 80px;
+ }
+ .large-gallery img {
+ width: 100%;
+ height: auto;
+ max-height: 150px;
+ }
+ """) as demo:
+ with gr.Column():
+ image_list = gr.State([])
+ gr.Markdown('# LVM Demo')
+ gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
+ gr.Markdown('## Inputs')
+ with gr.Row():
+ upload_drag = gr.File(
+ type='binary',
+ file_types=['image'],
+ file_count='multiple',
+ )
+ with gr.Column():
+ gen_length_slider = gr.Slider(
+ label='Generation length',
+ minimum=1,
+ maximum=32,
+ value=1,
+ step=1,
+ interactive=True,
+ )
+ n_candidates_slider = gr.Slider(
+ label='Number of candidates',
+ minimum=1,
+ maximum=10,
+ value=1,
+ step=1,
+ interactive=True,
+ )
+ temp_slider = gr.Slider(
+ label='Temperature',
+ minimum=0,
+ maximum=2.0,
+ value=1.0,
+ interactive=True,
+ )
+ top_p_slider = gr.Slider(
+ label='Top p',
+ minimum=0,
+ maximum=1.0,
+ value=0.9,
+ interactive=True,
+ )
+ clear_btn = gr.Button(
+ value='Clear',
+ elem_classes=['small-button'],
+ )
+ generate_btn = gr.Button(
+ value='Generate',
+ interactive=False,
+ elem_classes=['small-button'],
+ )
+ input_gallery = gr.Gallery(
+ columns=7,
+ rows=1,
+ object_fit='scale-down',
+ )
+ gr.Markdown('## Outputs')
+ output_gallery = gr.Gallery(
+ columns=4,
+ object_fit='scale-down',
+ )
+
+ def upload_image_fn(files, images):
+ for file in files:
+ images.append(Image.open(BytesIO(file)))
+
+ return {
+ upload_drag: None,
+ image_list: images,
+ input_gallery: images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ def clear_fn():
+ return {
+ image_list: [],
+ input_gallery: [],
+ generate_btn: gr.update(interactive=False),
+ output_gallery: [],
+ }
+
+ def disable_generate_btn():
+ return {
+ generate_btn: gr.update(interactive=False),
+ }
+
+ def generate_fn(images, n_candidates, gen_length, temperature, top_p):
+ new_images = generate_images(
+ images,
+ gen_length,
+ n_candidates=n_candidates,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ return {
+ output_gallery: new_images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ upload_drag.upload(
+ upload_image_fn,
+ inputs=[upload_drag, image_list],
+ outputs=[upload_drag, image_list, input_gallery, generate_btn],
+ )
+ clear_btn.click(
+ clear_fn,
+ inputs=None,
+ outputs=[image_list, input_gallery, generate_btn, output_gallery],
+ )
+ generate_btn.click(
+ disable_generate_btn,
+ inputs=None,
+ outputs=[generate_btn],
+ ).then(
+ generate_fn,
+ inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
+ outputs=[output_gallery, generate_btn],
+ )
+
+ example_groups = load_example_image_groups('/home/yutongbai/demo_images')
+
+ def add_image_group_fn(group_name, images):
+ new_images = images + example_groups[group_name]
+ return {
+ image_list: new_images,
+ input_gallery: new_images,
+ generate_btn: gr.update(interactive=True),
+ }
+
+ for group_name, group_images in example_groups.items():
+ with gr.Row():
+ with gr.Column(scale=3):
+ add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
+ with gr.Column(scale=7):
+ group_gallery = gr.Gallery(
+ value=[Image.fromarray(np.array(img)) for img in group_images],
+ columns=5,
+ rows=1,
+ object_fit='scale-down',
+ label=group_name,
+ elem_classes=['large-gallery'],
+ )
+
+ add_button.click(
+ add_image_group_fn,
+ inputs=[gr.State(group_name), image_list],
+ outputs=[image_list, input_gallery, generate_btn],
+ )
+
+ app = FastAPI()
+ app = gr.mount_gradio_app(app, demo, '/')
+ uvicorn.run(app, host=FLAGS.host, port=FLAGS.port)
+
+if __name__ == "__main__":
+ mlxu.run(main)
diff --git a/eval_perplexity.py b/eval_perplexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bdefb882fc039d294209f0bb5e254ede100f340
--- /dev/null
+++ b/eval_perplexity.py
@@ -0,0 +1,127 @@
+"""
+Evaluating the perplexity on few-shot tasks. This script accepts a jsonl file
+as input, where each line is a JSON dictionary representing one example in the
+evaluation set. The dictionary should have two keys:
+
+ input: a list of paths to the input images as context to the model. This
+ list should include the few shot examples.
+    target: a list of paths to the target images on which to evaluate perplexity.
+
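+For example, a line might look like this (paths are illustrative):
+
+    {"input": ["task/ex1_in.png", "task/ex1_out.png", "task/query_in.png"], "target": ["task/query_out.png"]}
+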
+This script runs the model and computes the average perplexity over the
+evaluation set.
+"""
+
+import os
+import json
+from PIL import Image
+import numpy as np
+import mlxu
+from tqdm import tqdm, trange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+
+from .inference import MultiProcessInferenceModel
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ input_file='',
+ checkpoint='',
+ input_base_dir='',
+ batch_size=2,
+ json_input_key='input',
+ json_target_key='target',
+ dtype='float16',
+ torch_devices='',
+ n_workers=4,
+ max_examples=0,
+)
+
+
+def read_image_to_tensor(path):
+ pil_im = Image.open(path).convert('RGB')
+ input_img = pil_im.resize((256, 256))
+ input_img = np.array(input_img) / 255.0
+ input_img = input_img.astype(np.float32)
+ return input_img
+
+
+class MultiFrameDataset(torch.utils.data.Dataset):
+ def __init__(self, input_files, target_files):
+ assert len(input_files) == len(target_files)
+ self.input_files = input_files
+ self.target_files = target_files
+
+ def __len__(self):
+ return len(self.input_files)
+
+ def __getitem__(self, idx):
+ input_list = np.stack(
+ [read_image_to_tensor(f) for f in self.input_files[idx]],
+ axis=0
+ )
+ target_list = np.stack(
+ [read_image_to_tensor(f) for f in self.target_files[idx]],
+ axis=0
+ )
+ return input_list, target_list
+
+
+def main(_):
+ assert FLAGS.checkpoint != ''
+
+ print(f'Loading checkpoint from {FLAGS.checkpoint}')
+ print(f'Evaluating input file from {FLAGS.input_file}')
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint,
+ torch_devices=FLAGS.torch_devices,
+ dtype=FLAGS.dtype,
+ use_lock=True,
+ perplexity_batch_size=FLAGS.batch_size,
+ )
+
+ input_files = []
+ target_files = []
+
+ with mlxu.open_file(FLAGS.input_file, 'r') as f:
+ for line in f:
+ record = json.loads(line)
+ input_files.append(record[FLAGS.json_input_key])
+ target_files.append(record[FLAGS.json_target_key])
+
+ if FLAGS.input_base_dir != '':
+ input_files = [
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
+ for y in input_files
+ ]
+ target_files = [
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
+ for y in target_files
+ ]
+
+ if FLAGS.max_examples > 0:
+ input_files = input_files[:FLAGS.max_examples]
+ target_files = target_files[:FLAGS.max_examples]
+
+ dataset = MultiFrameDataset(input_files, target_files)
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size * model.n_processes,
+ shuffle=False,
+ num_workers=FLAGS.n_workers
+ )
+
+ perplexities = []
+
+ for input_images, target_images in tqdm(data_loader, ncols=0):
+ perplexity = model.compute_perplexity(input_images, target_images)
+ perplexities.append(perplexity)
+
+ perplexities = np.concatenate(perplexities, axis=0)
+ print(f'Perplexity: {np.mean(perplexities)}')
+
+
+if __name__ == "__main__":
+ mlxu.run(main)
\ No newline at end of file
diff --git a/eval_video_perplexity.py b/eval_video_perplexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ba70ee8f53089c9bc456dcf9ca4dda0faade1f
--- /dev/null
+++ b/eval_video_perplexity.py
@@ -0,0 +1,134 @@
+
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+ is_video, random_square_crop,
+ read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ checkpoint='',
+ input_files='',
+ frame_input=False,
+ read_file_list='',
+ center_crop=1.0,
+ n_context_frames=15,
+ n_target_frames=1,
+ n_workers=8,
+ stride=8,
+ batch_size=2,
+ torch_devices='',
+ shuffle=False,
+ random_start=True,
+ max_examples=0,
+)
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+ def __init__(self, videos, frame_input=False, n_context_frames=15,
+ n_target_frames=1, stride=1):
+ self.videos = videos
+ self.frame_input = frame_input
+ self.n_context_frames = n_context_frames
+ self.n_target_frames = n_target_frames
+ self.stride = stride
+
+ def __getitem__(self, index):
+ if self.frame_input:
+ frames = read_frames_from_dir(
+ self.videos[index],
+ self.n_context_frames + self.n_target_frames,
+ self.stride,
+ center_crop=FLAGS.center_crop,
+ random_start=FLAGS.random_start,
+ )
+ else:
+ frames = read_frames_from_video(
+ self.videos[index],
+ self.n_context_frames + self.n_target_frames,
+ self.stride,
+ center_crop=FLAGS.center_crop,
+ random_start=FLAGS.random_start,
+ )
+ if frames is None:
+ return self[np.random.randint(0, len(self))]
+ return frames[:self.n_context_frames], frames[self.n_context_frames:]
+
+ def __len__(self):
+ return len(self.videos)
+
+
+
+def main(_):
+ assert FLAGS.checkpoint != ''
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint,
+ torch_devices=FLAGS.torch_devices,
+ perplexity_batch_size=FLAGS.batch_size,
+ )
+
+ if FLAGS.read_file_list != '':
+ with open(FLAGS.read_file_list, 'r') as f:
+ videos = [x.strip() for x in f.readlines()]
+ else:
+ videos = glob.glob(FLAGS.input_files)
+
+ if FLAGS.frame_input:
+ videos = [x for x in videos if os.path.isdir(x)]
+ else:
+ videos = [x for x in videos if is_video(x)]
+
+ if FLAGS.shuffle:
+ np.random.shuffle(videos)
+
+ if FLAGS.max_examples > 0:
+ videos = videos[:FLAGS.max_examples]
+
+ dataset = VideoDataset(
+ videos,
+ frame_input=FLAGS.frame_input,
+ n_context_frames=FLAGS.n_context_frames,
+ n_target_frames=FLAGS.n_target_frames,
+ stride=FLAGS.stride
+ )
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size * model.n_processes * 4,
+ shuffle=False,
+ num_workers=FLAGS.n_workers,
+ prefetch_factor=4,
+ drop_last=True,
+ )
+
+ perplexities = []
+
+    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
+        batch_context_frames = batch_context_frames.numpy()
+        batch_target_frames = batch_target_frames.numpy()
+        perplexity = model.compute_perplexity(
+            batch_context_frames, batch_target_frames
+        )
+ perplexities.append(perplexity)
+
+ perplexities = np.concatenate(perplexities, axis=0)
+ print(f'Perplexity: {np.mean(perplexities)}')
+
+
+if __name__ == '__main__':
+ mlxu.run(main)
\ No newline at end of file
diff --git a/eval_videos.py b/eval_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..4822e54d1de6e622eb5e2ccd6212f12c5130ea7a
--- /dev/null
+++ b/eval_videos.py
@@ -0,0 +1,160 @@
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+ is_video, random_square_crop,
+ read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ checkpoint='',
+ input_files='',
+ frame_input=False,
+ read_file_list='',
+ output_dir='',
+ center_crop=1.0,
+ n_context_frames=12,
+ n_new_frames=4,
+ n_candidates=8,
+ temperature=1.0,
+ top_p=1.0,
+ n_workers=8,
+ stride=8,
+ batch_size=32,
+ torch_devices='',
+ shuffle=False,
+ max_examples=0,
+)
+
+
+def save_image(args):
+ image, filename = args
+ base = FLAGS.input_files.split('*')[0]
+ filename = filename[len(base):].replace('/', '_') + '.png'
+ Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_frames=8, stride=1, n_new_frames=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_frames = n_frames
+        self.stride = stride
+        self.n_new_frames = n_new_frames
+
+ def __getitem__(self, index):
+ if self.frame_input:
+ frames = read_frames_from_dir(
+ self.videos[index], self.n_frames, self.stride,
+ center_crop=FLAGS.center_crop,
+ )
+
+ else:
+ # 's h w c'
+ frames = read_frames_from_video(
+ self.videos[index], self.n_frames, self.stride,
+ center_crop=FLAGS.center_crop,
+ )
+        if frames is None:
+            return self[np.random.randint(0, len(self))]
+
+        # The last n_new_frames frames of the clip are used as ground-truth targets.
+        target_frames = frames[self.n_frames - self.n_new_frames:self.n_frames, :, :, :]
+        return frames, target_frames, self.videos[index]
+
+ def __len__(self):
+ return len(self.videos)
+
+
+
+def main(_):
+ assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+ os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+ if FLAGS.read_file_list != '':
+ with open(FLAGS.read_file_list, 'r') as f:
+ videos = [x.strip() for x in f.readlines()]
+ else:
+ videos = glob.glob(FLAGS.input_files)
+
+ if FLAGS.frame_input:
+ videos = [x for x in videos if os.path.isdir(x)]
+ else:
+ videos = [x for x in videos if is_video(x)]
+
+ if FLAGS.shuffle:
+ np.random.shuffle(videos)
+
+ if FLAGS.max_examples > 0:
+ videos = videos[:FLAGS.max_examples]
+
+ dataset = VideoDataset(
+ videos,
+ frame_input=FLAGS.frame_input,
+ n_frames=FLAGS.n_context_frames,
+ stride=FLAGS.stride
+ )
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ shuffle=False,
+ num_workers=FLAGS.n_workers,
+ prefetch_factor=4,
+ drop_last=True,
+ )
+
+ if FLAGS.torch_devices == '':
+ torch_devices = None
+ else:
+ torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
+ )
+
+ save_img_pool = Pool(FLAGS.n_workers)
+
+
+    for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):
+        # batch: [batch, sequence, height, width, channel]
+        batch = batch.numpy()
+
+ generated = model(
+ batch,
+ n_new_frames=FLAGS.n_new_frames,
+ n_candidates=FLAGS.n_candidates,
+ temperature=FLAGS.temperature,
+ top_p=FLAGS.top_p,
+ )
+
+        generated = np.array(generated)
+
+        batch_targets = einops.repeat(
+            batch_targets.numpy(),
+            'b s h w c -> b n s h w c',  # batch, candidate, sequence, height, width, channel
+            n=FLAGS.n_candidates,
+        )
+
+
+if __name__ == '__main__':
+ mlxu.run(main)
\ No newline at end of file
diff --git a/generate_videos.py b/generate_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f6d129adb843cd53723f32e16555a3108cda76
--- /dev/null
+++ b/generate_videos.py
@@ -0,0 +1,168 @@
+
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+ is_video, random_square_crop,
+ read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+ checkpoint='',
+ input_files='',
+ frame_input=False,
+ read_file_list='',
+ output_dir='',
+ center_crop=1.0,
+ n_context_frames=12,
+ n_new_frames=4,
+ n_candidates=8,
+ temperature=1.0,
+ top_p=1.0,
+ n_workers=8,
+ stride=8,
+ batch_size=32,
+ torch_devices='',
+ shuffle=False,
+ max_examples=0,
+)
+
+
+def save_image(args):
+ image, filename = args
+ base = FLAGS.input_files.split('*')[0]
+ filename = filename[len(base):].replace('/', '_') + '.png'
+ Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+ def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
+ self.videos = videos
+ self.frame_input = frame_input
+ self.n_frames = n_frames
+ self.stride = stride
+
+ def __getitem__(self, index):
+ if self.frame_input:
+ frames = read_frames_from_dir(
+ self.videos[index], self.n_frames, self.stride,
+ center_crop=FLAGS.center_crop,
+ )
+ else:
+ frames = read_frames_from_video(
+ self.videos[index], self.n_frames, self.stride,
+ center_crop=FLAGS.center_crop,
+ )
+ if frames is None:
+ return self[np.random.randint(0, len(self))]
+ return frames, self.videos[index]
+
+ def __len__(self):
+ return len(self.videos)
+
+
+
+def main(_):
+ assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+ os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+ if FLAGS.read_file_list != '':
+ with open(FLAGS.read_file_list, 'r') as f:
+ videos = [x.strip() for x in f.readlines()]
+ else:
+ videos = glob.glob(FLAGS.input_files)
+
+ if FLAGS.frame_input:
+ videos = [x for x in videos if os.path.isdir(x)]
+ else:
+ videos = [x for x in videos if is_video(x)]
+
+ if FLAGS.shuffle:
+ np.random.shuffle(videos)
+
+ if FLAGS.max_examples > 0:
+ videos = videos[:FLAGS.max_examples]
+
+ dataset = VideoDataset(
+ videos,
+ frame_input=FLAGS.frame_input,
+ n_frames=FLAGS.n_context_frames,
+ stride=FLAGS.stride
+ )
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ shuffle=False,
+ num_workers=FLAGS.n_workers,
+ prefetch_factor=4,
+ drop_last=True,
+ )
+
+ if FLAGS.torch_devices == '':
+ torch_devices = None
+ else:
+ torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
+
+ model = MultiProcessInferenceModel(
+ checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
+ )
+
+ save_img_pool = Pool(FLAGS.n_workers)
+
+
+
+    for batch, filenames in tqdm(dataloader, ncols=0):
+        # batch: [batch, sequence, height, width, channel]
+        batch = batch.numpy()
+
+        generated = model(
+            batch,
+            n_new_frames=FLAGS.n_new_frames,
+            n_candidates=FLAGS.n_candidates,
+            temperature=FLAGS.temperature,
+            top_p=FLAGS.top_p,
+        )
+        generated = np.array(generated)
+
+        # Repeat the context frames once per candidate so they can be tiled
+        # next to the generated frames.
+        output_batch = einops.repeat(
+            batch,
+            'b s h w c -> b n s h w c',
+            n=FLAGS.n_candidates,
+        )
+        combined = einops.rearrange(
+            np.concatenate([output_batch, generated], axis=2),
+            'b n s h w c -> b (n h) (s w) c'
+        )
+        combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
+        save_img_pool.imap(save_image, zip(combined, filenames))
+
+    # Close the pool and wait for pending saves so the process does not exit early.
+    save_img_pool.close()
+    save_img_pool.join()
+
+
+if __name__ == '__main__':
+ mlxu.run(main)
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b1b816a632da1e9fa2ec70b608c1d04fc54e1a
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,240 @@
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+import time
+import os
+from functools import partial
+from copy import deepcopy
+from multiprocessing import Pool
+from threading import Lock
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+import einops
+from transformers import LlamaForCausalLM
+import spaces
+
+from vqvae_muse import VQGANModel, get_tokenizer_muse
+from torch_vqvae_model import get_tokenizer
+
+
+def get_torch_float_dtype(dtype):
+ if dtype in (torch.float16, torch.bfloat16, torch.float32):
+ return dtype
+ return {
+ 'float16': torch.float16,
+ 'fp16': torch.float16,
+ 'f16': torch.float16,
+ 'bfloat16': torch.bfloat16,
+ 'bf16': torch.bfloat16,
+ 'float32': torch.float32,
+ 'fp32': torch.float32,
+ 'f32': torch.float32,
+ }[dtype]
+
+
+def get_pid():
+ time.sleep(1)
+ return os.getpid()
+
+
+class InferenceModel(ABC):
+
+ @abstractmethod
+    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+ raise NotImplementedError()
+
+
+class LocalInferenceModel(InferenceModel):
+
+ def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
+ context_frames=16, use_lock=False):
+ self.checkpoint = checkpoint
+ self.dtype = dtype
+ self.torch_device = torch_device
+ self.context_frames = context_frames
+
+ # new tokenizer
+ self.tokenizer = get_tokenizer_muse()
+ self.tokenizer.to(self.torch_device)
+
+ self.model = LlamaForCausalLM.from_pretrained(
+ self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
+ ).to(self.torch_device)
+ print("torch device", self.torch_device)
+ print("init device", self.model.device)
+
+ if use_lock:
+ self.lock = Lock()
+ else:
+ self.lock = nullcontext()
+
+ @torch.no_grad()
+ def compute_perplexity(self, input_images, target_images):
+ input_images = np.array(input_images)
+ target_images = np.array(target_images)
+ assert len(input_images.shape) == 5 and len(target_images.shape) == 5 # [B, S, H, W, C]
+ assert input_images.shape[0] == target_images.shape[0]
+ batch_size = input_images.shape[0]
+ with self.lock:
+ input_images = torch.tensor(
+ einops.rearrange(input_images, 'b s h w c -> b s c h w')
+ ).to(self.torch_device)
+ target_images = torch.tensor(
+ einops.rearrange(target_images, 'b s h w c -> b s c h w')
+ ).to(self.torch_device)
+ input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
+ target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
+ all_ids = torch.cat([input_ids, target_ids], dim=1)
+ logits = self.model(all_ids).logits
+ log_probs = F.log_softmax(logits, dim=-1)
+ target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
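+            # Logits at position t predict token t + 1, so the predictions for the
+            # target tokens start one position before the first target token.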
+ target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
+ perplexity = torch.exp(
+ -torch.mean(
+ torch.sum(target_log_probs * target_ids_onehot, dim=-1),
+ dim=-1
+ )
+ )
+ return perplexity.detach().cpu().numpy()
+
+ @torch.no_grad()
+ def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
+ assert type(input_images) == np.ndarray
+ with self.lock:
+ input_images = np.array(input_images, dtype=np.float32)
+ input_images = torch.tensor(
+ einops.rearrange(input_images, 'b h w c -> b c h w')
+ ).to(self.torch_device)
+
+            # Make sure the model and tokenizer are on the target device; they may
+            # have been moved since initialization.
+ self.model.to(self.torch_device)
+ self.tokenizer.to(self.torch_device)
+
+ # new tokenizer
+ _, input_ids = self.tokenizer.encode(input_images)
+ input_ids = input_ids.view(1, -1)
+
+
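+            # Each frame corresponds to 256 VQ tokens; keep only the most recent
+            # (context_frames - 1) frames so there is room to generate new tokens.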
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+ new_tokens = []
+ current_context_frames = input_ids.shape[1] // 256
+            first_generation_left = self.context_frames - current_context_frames
+            first_new_frames = min(first_generation_left, n_new_frames)
+ input_ids = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=torch.ones_like(input_ids),
+ pad_token_id=8192,
+ max_new_tokens=256 * first_new_frames,
+ do_sample=True,
+ top_p=top_p,
+ temperature=temperature,
+ suppress_tokens=list(range(8192, self.model.vocab_size)),
+ )
+ new_tokens.append(input_ids[:, -256 * first_new_frames:])
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+ for _ in range(max(0, n_new_frames - first_new_frames)):
+ input_ids = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=torch.ones_like(input_ids),
+ pad_token_id=8192,
+ max_new_tokens=256,
+ do_sample=True,
+ top_p=top_p,
+ temperature=temperature,
+ suppress_tokens=list(range(8192, self.model.vocab_size)),
+ )
+ new_tokens.append(input_ids[:, -256:])
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+ new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
+ new_images = einops.rearrange(
+ torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
+ 'b c h w -> b h w c'
+ ).detach().cpu().numpy()
+ return new_images
+
+ @spaces.GPU(duration=180)
+ def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+ output = []
+ for seq in input_images:
+ output.append(
+ [self.generate_once(seq, n_new_frames, temperature, top_p)
+ for _ in range(n_candidates)]
+ )
+ return output
+
+
+class MultiProcessInferenceModel(InferenceModel):
+
+ def __init__(self, checkpoint, torch_devices=None, dtype='float16',
+ context_frames=16, use_lock=False, perplexity_batch_size=2):
+ if torch_devices is None or torch_devices == '':
+ torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
+
+ self.torch_devices = torch_devices
+ self.n_processes = len(torch_devices)
+ print(f'Using {self.n_processes} processes for inference')
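+        # Spawn one worker per device and record each worker's PID so the
+        # initializer can look up which device that process should use.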
+ self.worker_pool = Pool(self.n_processes)
+ self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
+ self.device_map = {
+ pid: torch_device
+ for pid, torch_device in zip(self.worker_pids, self.torch_devices)
+ }
+ self.worker_pool.starmap(
+ self.initialize_worker,
+ [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
+ )
+ self.perplexity_batch_size = perplexity_batch_size
+ if use_lock:
+ self.lock = Lock()
+ else:
+ self.lock = nullcontext()
+
+ @staticmethod
+ def initialize_worker(device_map, checkpoint, dtype, context_frames):
+ global _current_process_backend
+ torch_device = device_map[os.getpid()]
+ _current_process_backend = LocalInferenceModel(
+ checkpoint, dtype, torch_device, context_frames
+ )
+
+ @staticmethod
+ def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
+ return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)
+
+ @staticmethod
+ def compute_perplexity_once(input_images, target_images):
+ return _current_process_backend.compute_perplexity(input_images, target_images)
+
+ def compute_perplexity(self, input_images, target_images):
+ with self.lock:
+ map_args = []
+ for i in range(0, len(input_images), self.perplexity_batch_size):
+ map_args.append((
+ input_images[i : i + self.perplexity_batch_size],
+ target_images[i : i + self.perplexity_batch_size]
+ ))
+ outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
+ return np.concatenate(outputs, axis=0)
+
+ def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+ with self.lock:
+ map_args = []
+ for seq in input_images:
+ for _ in range(n_candidates):
+ map_args.append((seq, n_new_frames, temperature, top_p))
+
+ outputs = self.worker_pool.starmap(self.generate_once, map_args)
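+            # Regroup the flat list of worker outputs into one list of candidates per input sequence.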
+ reshaped_output = []
+ index = 0
+ for _ in range(len(input_images)):
+ candidates = []
+ for _ in range(n_candidates):
+ candidates.append(outputs[index])
+ index += 1
+ reshaped_output.append(candidates)
+ return reshaped_output
+
diff --git a/prompts/.DS_Store b/prompts/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..3b3b755588b8880f623b777ccfaf36048c71b851
Binary files /dev/null and b/prompts/.DS_Store differ
diff --git a/prompts/Composition/Slide1.png b/prompts/Composition/Slide1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c20dd2a265aef200d02fec3495af8cdb4fece30d
--- /dev/null
+++ b/prompts/Composition/Slide1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d926922e8e28f02c46e723b85d3d4969da271f892654ce492cf59cbf3f322a0
+size 194501
diff --git a/prompts/Composition/Slide10.png b/prompts/Composition/Slide10.png
new file mode 100644
index 0000000000000000000000000000000000000000..c8a0e88cbc36c13d575f223edd6681fe95f63a86
--- /dev/null
+++ b/prompts/Composition/Slide10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4480a3a7905b703ca1802e5391ea47e90e84cdc7eacb5229ade606ce4f5b6bb
+size 443693
diff --git a/prompts/Composition/Slide11.png b/prompts/Composition/Slide11.png
new file mode 100644
index 0000000000000000000000000000000000000000..0549fe330577e0adc697dc03fb284aa15b14f441
--- /dev/null
+++ b/prompts/Composition/Slide11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91cbe861bd47c4ec08e79bccdb64b993cc4b3b21549c346f834a985b1b0a1a6e
+size 464548
diff --git a/prompts/Composition/Slide12.png b/prompts/Composition/Slide12.png
new file mode 100644
index 0000000000000000000000000000000000000000..31116cc53413933baaaf93aa5c7a4373e713944d
--- /dev/null
+++ b/prompts/Composition/Slide12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d05d2db2a5e7bc7e33795583e10cdc03ea53bacd250010680a161ab07b7ad65
+size 487835
diff --git a/prompts/Composition/Slide13.png b/prompts/Composition/Slide13.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5c89f17f384cb855047917b3bdc589919cd4504
--- /dev/null
+++ b/prompts/Composition/Slide13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d94cfad17df77fa90ab84bdd89d3ad09938a5fe768b4e211c2bac140b36c12cb
+size 489967
diff --git a/prompts/Composition/Slide14.png b/prompts/Composition/Slide14.png
new file mode 100644
index 0000000000000000000000000000000000000000..de90d3fa3d4b3af9d32fbce6803389b072d26322
--- /dev/null
+++ b/prompts/Composition/Slide14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b42409ec1ca2ddbde1114eb8426a34c5e0064159e224af808b766ae003d2fd
+size 492423
diff --git a/prompts/Composition/Slide15.png b/prompts/Composition/Slide15.png
new file mode 100644
index 0000000000000000000000000000000000000000..871bd579690ad66a0c714a1e5fdc33846aa9147c
--- /dev/null
+++ b/prompts/Composition/Slide15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9919156ccdd9c2cbb30811529e94c83bb2afb277c90fed503ea4716be702cdde
+size 491891
diff --git a/prompts/Composition/Slide2.png b/prompts/Composition/Slide2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dbe072dca6d9ef3fe4491d740af5df5c6b010c68
--- /dev/null
+++ b/prompts/Composition/Slide2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0c9f6467cc732b562c167770d38a164162e4454127a242a16e3bdae7e717d27
+size 193143
diff --git a/prompts/Composition/Slide3.png b/prompts/Composition/Slide3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e8ac80de12ba6e1b5c4bf4e6bfa6ac7ebad7ad6
--- /dev/null
+++ b/prompts/Composition/Slide3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f702f10001fd9e7ad523753c884f8cef532da878d62656ffdbd566e104b67c7
+size 199394
diff --git a/prompts/Composition/Slide4.png b/prompts/Composition/Slide4.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e0643b567bf4a0181ff14b6b954356c30ad7b06
--- /dev/null
+++ b/prompts/Composition/Slide4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18c2d2384e4c97f35ae4cddc9bea4e600946eefefefff1f4fb683a51a54d4384
+size 202638
diff --git a/prompts/Composition/Slide5.png b/prompts/Composition/Slide5.png
new file mode 100644
index 0000000000000000000000000000000000000000..59c9ae2435567cdce73774b3ef5342d70a0f13da
--- /dev/null
+++ b/prompts/Composition/Slide5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4af292b97a2abe48d253fb2f1badd8d147402a3124fd12a2a0750307487c4f27
+size 190546
diff --git a/prompts/Composition/Slide6.png b/prompts/Composition/Slide6.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb2d05758aadbd0ee184bc274ece6ede5714dd6c
--- /dev/null
+++ b/prompts/Composition/Slide6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8b5fe9521e4094950384fce57733496363750b6a7c816ebae3cf43e6bcdb626
+size 173097
diff --git a/prompts/Composition/Slide7.png b/prompts/Composition/Slide7.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa5e53aa0367d64009edede3023ab3ab1cdfa196
--- /dev/null
+++ b/prompts/Composition/Slide7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce9418614363adfcd1b96b6df3b990d8204ed0a0341c348f9f340d7c128b4900
+size 174070
diff --git a/prompts/Composition/Slide8.png b/prompts/Composition/Slide8.png
new file mode 100644
index 0000000000000000000000000000000000000000..d214f14a6149fd5fdf7a9a55b558b2e74b192359
--- /dev/null
+++ b/prompts/Composition/Slide8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66a589f649600c65b7e808d824322d5a7c36b39675704cd5857fc31ce4f5af7f
+size 180144
diff --git a/prompts/Composition/Slide9.png b/prompts/Composition/Slide9.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2be5efa234c711a7da20cd11f717e797b1e9bf8
--- /dev/null
+++ b/prompts/Composition/Slide9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee514628d3da8c4525853c86d1e8d348de7a0641312a0e3c79fae6b5d73ae11f
+size 454702
diff --git a/prompts/Depth Estimation/1.png b/prompts/Depth Estimation/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Depth Estimation/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Depth Estimation/1_depth.png b/prompts/Depth Estimation/1_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4a416e3cf5ee9e7b46f95c5590905192d50f99f
--- /dev/null
+++ b/prompts/Depth Estimation/1_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b22aa119576ab691bab3db3fdd7eacf53dadc9e4cb3a9bfe4f4cb9c6fc0f6c6
+size 13888
diff --git a/prompts/Depth Estimation/2.png b/prompts/Depth Estimation/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Depth Estimation/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Depth Estimation/2_depth.png b/prompts/Depth Estimation/2_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..70e9ccb6b44f0284cac2e39bc3e9e4981d5ac373
--- /dev/null
+++ b/prompts/Depth Estimation/2_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37eacaf9208cf21693ae99802697e8894a9e8cf40cc221c704a50358f14dc954
+size 12257
diff --git a/prompts/Depth Estimation/3.png b/prompts/Depth Estimation/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Depth Estimation/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Depth Estimation/3_depth.png b/prompts/Depth Estimation/3_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..f39d22d9069072ee46affd295b14f8e02a70ceb6
--- /dev/null
+++ b/prompts/Depth Estimation/3_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a91a47d21378bef0d535f7e33c0185e60cc23baa7ced20bc1ffb028a5d95b5c4
+size 13332
diff --git a/prompts/Depth Estimation/4.png b/prompts/Depth Estimation/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Depth Estimation/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Depth Estimation/4_depth.png b/prompts/Depth Estimation/4_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..d578e93b900f2e0d74ca15723de3c92b90fc66d2
--- /dev/null
+++ b/prompts/Depth Estimation/4_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0685d0c4755206910cb1b1feea54a1e843cdee9dd140e414c0df56a885b68d85
+size 13447
diff --git a/prompts/Depth Estimation/5.png b/prompts/Depth Estimation/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Depth Estimation/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Depth Estimation/5_depth.png b/prompts/Depth Estimation/5_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..9790c23cff31b23afe755370fae8c4fbac6316f3
--- /dev/null
+++ b/prompts/Depth Estimation/5_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e9874362d8a1c85b0030590399a5f6388fe69b9dd42ec762313b97d37817eb7
+size 12020
diff --git a/prompts/Depth Estimation/6.png b/prompts/Depth Estimation/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Depth Estimation/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Depth Estimation/6_depth.png b/prompts/Depth Estimation/6_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f84aeef983f2a70a1a26033942ffbc1247eeaa5
--- /dev/null
+++ b/prompts/Depth Estimation/6_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:588df6be45864d2164b5e215429f268ba82731540adb07cd4ea47db0ca8f5319
+size 11946
diff --git a/prompts/Depth Estimation/7.png b/prompts/Depth Estimation/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Depth Estimation/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Depth Estimation/7_depth.png b/prompts/Depth Estimation/7_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe14a639f330ac40b0ec82a7c19e8f61befe9496
--- /dev/null
+++ b/prompts/Depth Estimation/7_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0edcab180411a6966d899de8f282870293a5d275e58d4a185e3cb31d9ca6b0d
+size 13252
diff --git a/prompts/Depth Estimation/8.png b/prompts/Depth Estimation/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Depth Estimation/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Eaten Apples/1.png b/prompts/Eaten Apples/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d32aa001b22d9a408645e06cc02ad52505e230a2
--- /dev/null
+++ b/prompts/Eaten Apples/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a75364bb67ce5741004e2bb18178b362fd1d4dee12a76d9ae4be2124fb3452a0
+size 199368
diff --git a/prompts/Eaten Apples/10.png b/prompts/Eaten Apples/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa963c0a65d4e7605efbabc10244c171172835be
--- /dev/null
+++ b/prompts/Eaten Apples/10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05f9235b7c283915d0d81b2423915f05f587b91d691ae0cae6f0bc5b68e84588
+size 142649
diff --git a/prompts/Eaten Apples/2.png b/prompts/Eaten Apples/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dadf5d493d454a5ec31696919aa2be33eea1d6ab
--- /dev/null
+++ b/prompts/Eaten Apples/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25b08de2de0ac2bcc59be060bf19574931091c9dc6472f8122f7ac1243c59c6f
+size 214103
diff --git a/prompts/Eaten Apples/3.png b/prompts/Eaten Apples/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..17bb70919745f0f0401140da26a4aa67c9224875
--- /dev/null
+++ b/prompts/Eaten Apples/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eee49d97068ac9de19bf6aead9a4b0c88ba9108bc9eb9d19f43a3b5919c88367
+size 212059
diff --git a/prompts/Eaten Apples/4.png b/prompts/Eaten Apples/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..bc6979b7ca710ff50f9af0f26f6089d9a25a0b53
--- /dev/null
+++ b/prompts/Eaten Apples/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88d401e4c2c2b1b21119b953e230a276305af15d295d0035a221e498665af5b4
+size 212147
diff --git a/prompts/Eaten Apples/5.png b/prompts/Eaten Apples/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..60e4028c1752e1caf2581949529b7c76db2b788b
--- /dev/null
+++ b/prompts/Eaten Apples/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:116e8eb9ecc170c4a00f54a3b7b8996b67cd585932e34f4e2a25f8e589b7ae3d
+size 204197
diff --git a/prompts/Eaten Apples/6.png b/prompts/Eaten Apples/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..3011620226f9165f918dd3e1120a3569e4ef3bfd
--- /dev/null
+++ b/prompts/Eaten Apples/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c683df4c90c7da5bee98499fbf7233b6ac13fe2480aa9e1d4cb80a25ff9a500
+size 192756
diff --git a/prompts/Eaten Apples/7.png b/prompts/Eaten Apples/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c890af0c7b3b95c297ff07db0da4134754de58b4
--- /dev/null
+++ b/prompts/Eaten Apples/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e3ad84f16c9326a9819da7e2c9485705b073a47097f742592ada91c10f706c0
+size 181082
diff --git a/prompts/Eaten Apples/8.png b/prompts/Eaten Apples/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..89545aa186079020196c335da0200c2a1c88c80a
--- /dev/null
+++ b/prompts/Eaten Apples/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:919a814cf9c604923a3384c26473750d353bac17f3422865b68c1d86e45552f7
+size 167449
diff --git a/prompts/Edge Detection/1.png b/prompts/Edge Detection/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Edge Detection/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Edge Detection/1_edge.png b/prompts/Edge Detection/1_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..48f7dee44046eb0e7853d81be37870da7f7578fb
--- /dev/null
+++ b/prompts/Edge Detection/1_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ca5651653eb5ed8e3d08600c936985982f5019ce6f2f2489e82c112ea75686a
+size 30563
diff --git a/prompts/Edge Detection/2.png b/prompts/Edge Detection/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Edge Detection/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Edge Detection/2_edge.png b/prompts/Edge Detection/2_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..49db076ecece6c9db3bee69a720b8207fadfd8e5
--- /dev/null
+++ b/prompts/Edge Detection/2_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8a5327db1e20457b8ab478a296e5176f2779c5e4bb2c034e5b6f0183854866b
+size 30437
diff --git a/prompts/Edge Detection/3.png b/prompts/Edge Detection/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Edge Detection/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Edge Detection/3_edge.png b/prompts/Edge Detection/3_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f234ff898f4fc9f390bc4d3e5fd114662411b51
--- /dev/null
+++ b/prompts/Edge Detection/3_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d12107a3ead99860b853db28ddd7c8cc3fe65966cdb9221f6afd2154dadeb507
+size 32196
diff --git a/prompts/Edge Detection/4.png b/prompts/Edge Detection/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Edge Detection/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Edge Detection/4_edge.png b/prompts/Edge Detection/4_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..171d43defccfc3a9473faff804b31564483380d6
--- /dev/null
+++ b/prompts/Edge Detection/4_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a4d3730f25b2b7305dfc56875b12135b99c0a93cba8ac1a1ff899b7d68eb8ef
+size 39602
diff --git a/prompts/Edge Detection/5.png b/prompts/Edge Detection/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Edge Detection/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Edge Detection/5_edge.png b/prompts/Edge Detection/5_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2b7f3b15d3e3b4aab01eabebccb8556531d7619
--- /dev/null
+++ b/prompts/Edge Detection/5_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c958bf6a35c685e34ae07216dd112ac4fcedd9c4629d096615333bb69b45d45c
+size 16448
diff --git a/prompts/Edge Detection/6.png b/prompts/Edge Detection/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Edge Detection/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Edge Detection/6_edge.png b/prompts/Edge Detection/6_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..9d62a8aa115b5cde4f7a317a744f96673102a8c4
--- /dev/null
+++ b/prompts/Edge Detection/6_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75383c3fb8acb0318fb95c48c43b56df8574ca119e0f9f577066dab8fdb8fca3
+size 36706
diff --git a/prompts/Edge Detection/7.png b/prompts/Edge Detection/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Edge Detection/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Edge Detection/7_edge.png b/prompts/Edge Detection/7_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..be7cf97bb9f5f33d331ece657a7f23817bccf235
--- /dev/null
+++ b/prompts/Edge Detection/7_edge.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93a646f6a4fcd48b949aba04dbdeb781b01e0425a32a03d960e7c9617375fe90
+size 29210
diff --git a/prompts/Edge Detection/8.png b/prompts/Edge Detection/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Edge Detection/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Emoji/smile1.png b/prompts/Emoji/smile1.png
new file mode 100644
index 0000000000000000000000000000000000000000..069348959a22a84e5993121e89cf401fa576c923
--- /dev/null
+++ b/prompts/Emoji/smile1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8923aa6be860685271dbef2fcc1fc87230dec0dca44d0815f9ccdf1f8d5aea26
+size 21247
diff --git a/prompts/Emoji/smile2.png b/prompts/Emoji/smile2.png
new file mode 100644
index 0000000000000000000000000000000000000000..01948d4b355fdccce546cd367554db3724231938
--- /dev/null
+++ b/prompts/Emoji/smile2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beadf44be4d267c8f88d684bfd6b02ca13eeff5fb007f2fbfc8e2f52b8459c64
+size 22703
diff --git a/prompts/Emoji/smile3.png b/prompts/Emoji/smile3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a527c61b51f40daa4ad8425342a3918f140143bf
--- /dev/null
+++ b/prompts/Emoji/smile3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:919155477aa8d2d17502cacf0dbdcec8b8feba5faa2065e86b832ac6020e2169
+size 24308
diff --git a/prompts/Emoji/smile4.png b/prompts/Emoji/smile4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc8f33154f294ae7c6141c6df333645f8e314f80
--- /dev/null
+++ b/prompts/Emoji/smile4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9571b1c938cd2ccb522bdd15cddb4ff93b84c6ac938923a34f1ee01fdb2a002a
+size 24451
diff --git a/prompts/Object Tracking/Picture1.png b/prompts/Object Tracking/Picture1.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ee9959fc01d71202482304f5031bf3cb7fa04db
--- /dev/null
+++ b/prompts/Object Tracking/Picture1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d311a18e7a9c934ce775fa6db6576f9ed27b9fcac82f224626af8b5b074b3c35
+size 733163
diff --git a/prompts/Object Tracking/Picture2.png b/prompts/Object Tracking/Picture2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f23477cb63d62c8b57fe69360905c617bde0155
--- /dev/null
+++ b/prompts/Object Tracking/Picture2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a2437f5bd694d69a384c82ad2b163eb665b50f2814a4a30cc1719e4436393f
+size 730944
diff --git a/prompts/Object Tracking/Picture3.png b/prompts/Object Tracking/Picture3.png
new file mode 100644
index 0000000000000000000000000000000000000000..580685c71839bb687193ff48f07300fb423ed80b
--- /dev/null
+++ b/prompts/Object Tracking/Picture3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d31cdf20c257094e3624cc032c975a5326d904d53664ed7acfd022c829961e96
+size 723189
diff --git a/prompts/Object Tracking/Picture4.png b/prompts/Object Tracking/Picture4.png
new file mode 100644
index 0000000000000000000000000000000000000000..85b9aa885f9187d6b7124a04356874280a099caf
--- /dev/null
+++ b/prompts/Object Tracking/Picture4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb5e559a8421bd48bcdb2689142b52ad3384a49d8b59c8f00fb65ab583c62cb1
+size 709894
diff --git a/prompts/Object Tracking/Picture5.png b/prompts/Object Tracking/Picture5.png
new file mode 100644
index 0000000000000000000000000000000000000000..f8b0f91b49529756511fd094ced04288c2636137
--- /dev/null
+++ b/prompts/Object Tracking/Picture5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e486056f0c6898326b54410878f1c0deeba2a5c5db7acec9ba5a14617ab59068
+size 690117
diff --git a/prompts/Object Tracking/Picture6.png b/prompts/Object Tracking/Picture6.png
new file mode 100644
index 0000000000000000000000000000000000000000..10e56c3f9db3a014eef054561236356c444d84c3
--- /dev/null
+++ b/prompts/Object Tracking/Picture6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf16762c2dd316e52721618f96bcc0811118f812427e933fef50f937adaea3d4
+size 671387
diff --git a/prompts/Object Tracking/Picture7.png b/prompts/Object Tracking/Picture7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c41b3438aa20a8a8f6460003206c945100658073
--- /dev/null
+++ b/prompts/Object Tracking/Picture7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:227bf301dc7f06208e897ab3b6092e15e0b32f7ca0ef77a51f8fb81bd258e9a3
+size 654387
diff --git a/prompts/Object Tracking/Picture8.png b/prompts/Object Tracking/Picture8.png
new file mode 100644
index 0000000000000000000000000000000000000000..6de5ddbfcfa353d01d176426271f2fca2c770961
--- /dev/null
+++ b/prompts/Object Tracking/Picture8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e38170d147a836c88782b37d7f2bd60ee051016c6daf53ea879320f0967c818c
+size 642514
diff --git a/prompts/Object Tracking/Picture9.png b/prompts/Object Tracking/Picture9.png
new file mode 100644
index 0000000000000000000000000000000000000000..117698ee19e804cc082023da4dff47427f68532f
--- /dev/null
+++ b/prompts/Object Tracking/Picture9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14dc7b90ae871b5be240e5775e1bd0be2ae18e73ef5ac87e7a96a928613c03b5
+size 622249
diff --git a/prompts/Outpainting/2.png b/prompts/Outpainting/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b77b519a6a263166ce81d41efc2a17e528c5dafc
--- /dev/null
+++ b/prompts/Outpainting/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d73f108e47033bf08b8f5b4558e72c73231f5cb06e06c6283b782fe260739f1
+size 440489
diff --git a/prompts/Outpainting/3.png b/prompts/Outpainting/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e2d1fdfecf4447a2475ed09e23035f11d161e95
--- /dev/null
+++ b/prompts/Outpainting/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fba02eef20fd3a32ff01fc43bae60811185ce34cac343699a4e2b7383a81e8ef
+size 829260
diff --git a/prompts/Outpainting/4.png b/prompts/Outpainting/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..0375a22e663f9855153362011d827ee04dda1011
--- /dev/null
+++ b/prompts/Outpainting/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e36507bd52fc628b911f8e3d2806549344fa51a41821bb8a31b772296c5c8207
+size 1295391
diff --git a/prompts/Outpainting/5.png b/prompts/Outpainting/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..48e5357ccdadd953ecbeb7e45c832f44f063f9d1
--- /dev/null
+++ b/prompts/Outpainting/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b3a730e6ad1adf868aac31dfadbd2773df5850b3602dcb84e4c10b57e3f4a1
+size 1980696
diff --git a/prompts/Segmentation/1.png b/prompts/Segmentation/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Segmentation/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Segmentation/1_seg.png b/prompts/Segmentation/1_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ea7e500d9d2dc981deb196fb5693acf96459636
--- /dev/null
+++ b/prompts/Segmentation/1_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2642904a2acd1bbf4cc099ccf3bb0d0f7ee3370079eda0a21d04df75923a2dd
+size 1234
diff --git a/prompts/Segmentation/2.png b/prompts/Segmentation/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Segmentation/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Segmentation/2_seg.png b/prompts/Segmentation/2_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..a38cbb9400b37bcda2aa0d92c022a2883512728b
--- /dev/null
+++ b/prompts/Segmentation/2_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1070ec66a0b1295a20871f0ef04e2f150e708b8fedcd915733b62978cbfcd0d
+size 2243
diff --git a/prompts/Segmentation/3.png b/prompts/Segmentation/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Segmentation/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Segmentation/3_seg.png b/prompts/Segmentation/3_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f3645ca6724fce3c749631221a657ae42dba25c
--- /dev/null
+++ b/prompts/Segmentation/3_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf8dc0e197459bdef57d3246e3cd5355f7a8c6d5f4822eb05fe1a8813d3f71b5
+size 1869
diff --git a/prompts/Segmentation/4.png b/prompts/Segmentation/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Segmentation/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Segmentation/4_seg.png b/prompts/Segmentation/4_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..358cb0caa8e26115253f7e579484ec207ecf11bf
--- /dev/null
+++ b/prompts/Segmentation/4_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:101793eff1e1cb3613cff94fd1cbd3e65ce9407a29dd9efd8ffa78f90443fa72
+size 3129
diff --git a/prompts/Segmentation/5.png b/prompts/Segmentation/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Segmentation/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Segmentation/5_seg.png b/prompts/Segmentation/5_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..4356240ef23bfbdff9edbf1a5ea36d75bae5e93a
--- /dev/null
+++ b/prompts/Segmentation/5_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fda0b493538ed52658c31498c678ce098229bd9c57c3e76db9096298f2b2309d
+size 1814
diff --git a/prompts/Segmentation/6.png b/prompts/Segmentation/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Segmentation/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Segmentation/6_seg.png b/prompts/Segmentation/6_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0aa3baf6aed286426f0dd6a6190ce0c094333df
--- /dev/null
+++ b/prompts/Segmentation/6_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10e8d222b3db095de09124c8b129ac63da360776d588b1fb5027b31fa9b5d1d0
+size 1684
diff --git a/prompts/Segmentation/7.png b/prompts/Segmentation/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Segmentation/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Segmentation/7_seg.png b/prompts/Segmentation/7_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..f768b93c59c7e66c0617d1160a5381080b8b44f4
--- /dev/null
+++ b/prompts/Segmentation/7_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c11400a1bc6c696c2f479165e5ffe7d0be2212c252b311b29eeae2d111c927a
+size 1887
diff --git a/prompts/Segmentation/8.png b/prompts/Segmentation/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Segmentation/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Surface Normal/1.png b/prompts/Surface Normal/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Surface Normal/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Surface Normal/1_surfave_norm.png b/prompts/Surface Normal/1_surfave_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9865e90d1661296e48a27338eef101aa5b26ba3
--- /dev/null
+++ b/prompts/Surface Normal/1_surfave_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f40b8f98c35392de51e72db501b26dd04f3fa057287839df4076f0402c2fe05b
+size 42875
diff --git a/prompts/Surface Normal/2.png b/prompts/Surface Normal/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Surface Normal/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Surface Normal/2_surface_norm.png b/prompts/Surface Normal/2_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ca2da625560053b50957186fc30d9c9249333b5
--- /dev/null
+++ b/prompts/Surface Normal/2_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e52e3619a44c896fbfcc4cd5a60435fb2187c907da2122c7a5584363e8b7109
+size 45167
diff --git a/prompts/Surface Normal/3.png b/prompts/Surface Normal/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Surface Normal/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Surface Normal/3_surface_norm.png b/prompts/Surface Normal/3_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..a17702432fea0029d9ad22445dac6b9ba1a2b4af
--- /dev/null
+++ b/prompts/Surface Normal/3_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2bb1d3d41c0f72faf0c4fca3c0030021650d0f36707cdfdb381d3ef9a5edb6b
+size 49907
diff --git a/prompts/Surface Normal/4.png b/prompts/Surface Normal/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Surface Normal/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Surface Normal/4_surface_norm.png b/prompts/Surface Normal/4_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..b92d9199042136de71381a3a0d3463d070f96728
--- /dev/null
+++ b/prompts/Surface Normal/4_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95d833513daa9de65dc68f3eeb4e589ec43d2e42e7305c789f4235f6213c5338
+size 46127
diff --git a/prompts/Surface Normal/5.png b/prompts/Surface Normal/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Surface Normal/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Surface Normal/5_surface_norm.png b/prompts/Surface Normal/5_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..d30c9ee11141aceff4713cc580a857d876a7c9f8
--- /dev/null
+++ b/prompts/Surface Normal/5_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9898e670f6c5e5bdd4bb4836b7ad9632ea90e2e91d57a524f22214c9cf6ef2cc
+size 34050
diff --git a/prompts/Surface Normal/6.png b/prompts/Surface Normal/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Surface Normal/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Surface Normal/6_surface_norm.png b/prompts/Surface Normal/6_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..8188e4041605d668f198a92ba661882417302e51
--- /dev/null
+++ b/prompts/Surface Normal/6_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b87d3709a62bfced6176861c06102bb1fe45583e92fadacb5a38562bba34339
+size 45259
diff --git a/prompts/Surface Normal/7.png b/prompts/Surface Normal/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Surface Normal/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Surface Normal/7_surface_norm.png b/prompts/Surface Normal/7_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8955586a49136487bbd689017efa9d73ad2544a
--- /dev/null
+++ b/prompts/Surface Normal/7_surface_norm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84b9147bdcfa2c56c05bafc73d025efc251ba9b687dc9a83a00e2f9633c7a172
+size 44708
diff --git a/prompts/Surface Normal/8.png b/prompts/Surface Normal/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Surface Normal/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Synthetic Object Replication/Slide1.png b/prompts/Synthetic Object Replication/Slide1.png
new file mode 100644
index 0000000000000000000000000000000000000000..aaa3ff562f23d81a718d76104668bd190ea6ebc7
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d00ed3bc13809d3c1d54fbf851c334e3a0e6f088a4c7420c73af078e5517e05
+size 229453
diff --git a/prompts/Synthetic Object Replication/Slide2.png b/prompts/Synthetic Object Replication/Slide2.png
new file mode 100644
index 0000000000000000000000000000000000000000..3de22bebc4cd4cc363f223ef323d48b0c75b3719
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29aeaa8d76977173a58a957c463cb794a5a8d780b092706836fea4832b787f35
+size 235912
diff --git a/prompts/Synthetic Object Replication/Slide3.png b/prompts/Synthetic Object Replication/Slide3.png
new file mode 100644
index 0000000000000000000000000000000000000000..e45d49883dc9528572525fc3b7e3ffff941eb68f
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0093baa4b93b0f75872637787cda5412fc371195b3f91188c1bf6e03dfdd49d2
+size 233967
diff --git a/prompts/Synthetic Object Replication/Slide4.png b/prompts/Synthetic Object Replication/Slide4.png
new file mode 100644
index 0000000000000000000000000000000000000000..97991be222d6e76c92a53ffb76a3e307fff53641
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcce41ad98b8a7a3b7f1153b48813e936a06590bb716c184fbe012045063f815
+size 238995
diff --git a/prompts/Synthetic Object Replication/Slide5.png b/prompts/Synthetic Object Replication/Slide5.png
new file mode 100644
index 0000000000000000000000000000000000000000..22565ba0e7b64c76e00fbb0519570f6ac35980b9
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a686955f717be1a514f68786268394eea86485b3aba333913d62c6548c8417d
+size 236567
diff --git a/prompts/Synthetic Object Replication/Slide6.png b/prompts/Synthetic Object Replication/Slide6.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bcb3367c8db69f98a63a9995e5147458d66ca11
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f1b4e759376a1c26a837ff59a918ffd85a1689cb25064b39e8f7976aacb3ba7
+size 240757
diff --git a/prompts/Synthetic Object Replication/Slide7.png b/prompts/Synthetic Object Replication/Slide7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5d7852cadabb41ae06e051d81d864916ed35abd
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1132a1a1f41a406069b628dcf2d81995ef48bc0a493d3e501c266c617d722499
+size 238630
diff --git a/prompts/Synthetic Object Replication/Slide8.png b/prompts/Synthetic Object Replication/Slide8.png
new file mode 100644
index 0000000000000000000000000000000000000000..953743749cbd61eaa0d6371dadee67cccf4156a4
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de4e4c3b27e1d7691b697704c6506081e2cd5903929dca66deab53c4be08d886
+size 235972
diff --git a/prompts/Synthetic Object Replication/Slide9.png b/prompts/Synthetic Object Replication/Slide9.png
new file mode 100644
index 0000000000000000000000000000000000000000..7430d331a7cb500dbe9852d0e5ad4995b3d132d8
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2dbf3f768e24d4e46c04bad22f41798d717a15c566ad3d6b0d0813e23feb995f
+size 233023
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..18ebbfe08f54d5ba0a5ca1c4f58220ce4da99c4d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+numpy
+scipy
+matplotlib
+seaborn
+jupyter
+tqdm
+pillow
+--extra-index-url https://download.pytorch.org/whl/cu118
+transformers==4.34.1
+torch==2.0.1
+einops
+absl-py
+ml_collections
+requests
+mlxu==0.1.11
+pydantic
+fastapi
+uvicorn
+gradio
+opencv-python-headless
+scikit-video
+scikit-image
+natsort
+accelerate
diff --git a/torch_vqvae_model.py b/torch_vqvae_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe6053b36c5410da5f3a914e1a6644f840734311
--- /dev/null
+++ b/torch_vqvae_model.py
@@ -0,0 +1,257 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+from einops.layers.torch import Rearrange
+
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+def swish(x):
+ return x*torch.sigmoid(x)
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.norm2 = normalize(self.out_channels)
+ self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+ self.activation_fn = activation_fn
+ if activation_fn=="relu":
+ self.actn = nn.ReLU()
+
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ if self.activation_fn=="relu":
+ x = self.actn(x)
+ elif self.activation_fn=="swish":
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ if self.activation_fn=="relu":
+ x = self.actn(x)
+ elif self.activation_fn=="swish":
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+class Encoder(nn.Module):
+ def __init__(self, ):
+ super().__init__()
+
+ self.filters = 128
+ self.num_res_blocks = 2
+ self.ch_mult = [1,1,2,2,4]
+ self.in_ch_mult = (1,)+tuple(self.ch_mult)
+ self.embedding_dim = 32
+ self.conv_downsample = False
+
+ self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
+ blocks = []
+ for i in range(len(self.ch_mult)):
+ block_in_ch = self.filters * self.in_ch_mult[i]
+ block_out_ch = self.filters * self.ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+ block_in_ch = block_out_ch
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+ self.norm1 = normalize(block_in_ch)
+ self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ for i in range(len(self.ch_mult)):
+ for j in range(self.num_res_blocks):
+ x = self.blocks[i*2+j](x)
+
+ if i < len(self.ch_mult) -1:
+ x = torch.nn.functional.avg_pool2d(x, (2,2),(2,2))
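+ # Each of the first len(self.ch_mult) - 1 stages halves the spatial resolution,
+ # so a 256x256 input reaches the quantizer as a 16x16 grid of 32-dim latents.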
+
+ x = self.blocks[-2](x)
+ x = self.blocks[-1](x)
+
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv2(x)
+ return x
+
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+ self.beta = 0.0 if beta is None else beta
+ self.z_dim = emb_dim
+
+ def forward(self, z):
+ # preprocess
+
+ b, c, h, w = z.size()
+ flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
+ codebook = self.embedding.weight
+ with torch.no_grad():
+ tokens = torch.cdist(flatten, codebook).argmin(dim=1)
+ quantized = F.embedding(tokens,
+ codebook).view(b, h, w, c).permute(0, 3, 1, 2)
+
+ # compute loss
+ codebook_loss = F.mse_loss(quantized, z.detach())
+ commitment_loss = F.mse_loss(quantized.detach(), z)
+ loss = codebook_loss + self.beta * commitment_loss
+
+ # perplexity
+ counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
+ # dist.all_reduce(counts)
+ p = counts / counts.sum()
+ perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
+
+ # postprocess
+ tokens = tokens.view(b, h, w)
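+ # Straight-through estimator: the forward pass uses the quantized latents, while
+ # gradients flow back to the encoder as if quantization were the identity map.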
+ quantized = z + (quantized - z).detach()
+
+ # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c))
+
+ return quantized, tokens, loss, perplexity
+
+
+ def get_codebook_feat(self, indices, shape=None):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class Decoder(nn.Module):
+ def __init__(self,):
+ super().__init__()
+ self.filters = 128
+ self.num_res_blocks = 2
+ self.ch_mult = [1,1,2,2,4]
+ self.in_ch_mult = (1,)+tuple(self.ch_mult)
+ self.embedding_dim = 32
+ self.out_channels = 3
+ self.in_channels = self.embedding_dim
+ self.conv_downsample = False
+
+ self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1)
+ blocks = []
+ block_in_ch = self.filters * self.ch_mult[-1]
+ block_out_ch = self.filters * self.ch_mult[-1]
+ #blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+ upsample_conv_layers = []
+ for i in reversed(range(len(self.ch_mult))):
+ block_out_ch = self.filters * self.ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+ block_in_ch = block_out_ch
+ if i > 0:
+ upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch*4, kernel_size=3, stride=1, padding=1))
+
+ self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
+ self.norm1 = normalize(block_in_ch)
+ # self.act_fn
+ self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
+ self.blocks = nn.ModuleList(blocks)
+ self.up_convs = nn.ModuleList(upsample_conv_layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.blocks[0](x)
+ x = self.blocks[1](x)
+ for i in range(len(self.ch_mult)):
+ for j in range(self.num_res_blocks):
+ x = self.blocks[2+i*2+j](x)
+ if i < len(self.ch_mult)-1:
+ x = self.up_convs[i](x)
+ #print("pre: x.size()",x.size())
+ x = x.permute(0,2,3,1)
+ x = self.upsample(x)
+ x = x.permute(0,3,1,2)
+ #print("post: x.size()", x.size())
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv6(x)
+ return x
+
+
+class VQVAE(nn.Module):
+ def __init__(self, ):
+ super().__init__()
+ self.encoder = Encoder()
+ self.quantizer = VectorQuantizer()
+ self.decoder = Decoder()
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, tokens, loss, perplexity = self.quantizer(x)
+ x = self.decoder(quant)
+ return x
+
+ def tokenize(self, x):
+ batch_shape = x.shape[:-3]
+ x = x.reshape(-1, *x.shape[-3:])
+ x = self.encoder(x)
+ quant, tokens, loss, perplexity = self.quantizer(x)
+ return tokens.reshape(*batch_shape, *tokens.shape[1:])
+
+ def decode(self, tokens):
+ tokens = einops.rearrange(tokens, 'b ... -> b (...)')
+ b = tokens.shape[0]
+ if tokens.shape[-1] == 256:
+ hw = 16
+ elif tokens.shape[-1] == 224:
+ hw = 14
+ else:
+ raise ValueError("Invalid tokens shape")
+ quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32))
+ x = self.decoder(quant)
+ return x
+
+
+class VAEDecoder(nn.Module):
+ def __init__(self, ):
+ super().__init__()
+ self.quantizer = VectorQuantizer()
+ self.decoder = Decoder()
+
+ def forward(self, x):
+ quant = self.quantizer.get_codebook_feat(x,(1,14,14,32))
+ x = self.decoder(quant)
+ return x
+
+
+def get_tokenizer():
+ checkpoint_path = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
+ )
+ torch_state_dict = torch.load(checkpoint_path)
+ net = VQVAE()
+ net.load_state_dict(torch_state_dict)
+ return net
+
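+# Illustrative usage sketch (added for clarity; assumes the "xh_ckpt.pth" checkpoint
+# is present next to this module and that inputs are float images in [0, 1]):
+#
+#   import torch
+#   tokenizer = get_tokenizer().eval()
+#   images = torch.rand(2, 3, 256, 256)
+#   with torch.no_grad():
+#       tokens = tokenizer.tokenize(images)   # (2, 16, 16) codebook indices
+#       recon = tokenizer.decode(tokens)      # (2, 3, 256, 256) reconstruction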
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b059a22e2fa8786ebba48dab800318b362507715
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,296 @@
+import os
+from multiprocessing import Pool
+import numpy as np
+import random
+from PIL import Image
+import re
+import cv2
+import glob
+from natsort import natsorted
+
+
+class MultiProcessImageSaver(object):
+
+ def __init__(self, n_workers=1):
+ self.pool = Pool(n_workers)
+
+ def __call__(self, images, output_files, resizes=None):
+ if resizes is None:
+ resizes = [None for _ in range(len(images))]
+ return self.pool.imap(
+ self.save_image,
+ zip(images, output_files, resizes),
+ )
+
+ def close(self):
+ self.pool.close()
+ self.pool.join()
+
+ @staticmethod
+ def save_image(args):
+ image, filename, resize = args
+ image = Image.fromarray(image)
+ if resize is not None:
+ image = image.resize(tuple(resize))
+ image.save(filename)
+
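+# Illustrative usage sketch (added; the output paths below are placeholders):
+#
+#   saver = MultiProcessImageSaver(n_workers=4)
+#   arrays = [np.zeros((256, 256, 3), dtype=np.uint8)] * 2
+#   list(saver(arrays, ['/tmp/a.png', '/tmp/b.png']))  # drain the imap iterator
+#   saver.close()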
+
+def list_dir_with_full_path(path):
+ return [os.path.join(path, f) for f in os.listdir(path)]
+
+
+def find_all_files_in_dir(path):
+ all_files = []
+ for root, _, files in os.walk(path):
+ for file in files:
+ all_files.append(os.path.join(root, file))
+ return all_files
+
+
+def is_image(path):
+ return (
+ path.endswith('.jpg')
+ or path.endswith('.png')
+ or path.endswith('.jpeg')
+ or path.endswith('.JPG')
+ or path.endswith('.PNG')
+ or path.endswith('.JPEG')
+ )
+
+
+def is_video(path):
+ return (
+ path.endswith('.mp4')
+ or path.endswith('.avi')
+ or path.endswith('.MP4')
+ or path.endswith('.AVI')
+ or path.endswith('.webm')
+ or path.endswith('.WEBM')
+ or path.endswith('.mkv')
+ or path.endswith('.MKV')
+ )
+
+
+def random_square_crop(img, random_generator=None):
+ # If no random generator is provided, use numpy's default
+ if random_generator is None:
+ random_generator = np.random.default_rng()
+
+ # Get the width and height of the image
+ width, height = img.size
+
+ # Determine the shorter side
+ min_size = min(width, height)
+
+ # Randomly determine the starting x and y coordinates for the crop
+ if width > height:
+ left = random_generator.integers(0, width - min_size)
+ upper = 0
+ else:
+ left = 0
+ upper = random_generator.integers(0, height - min_size)
+
+ # Calculate the ending x and y coordinates for the crop
+ right = left + min_size
+ lower = upper + min_size
+
+ # Crop the image
+ return img.crop((left, upper, right, lower))
+
+
+def read_image_to_tensor(path, center_crop=1.0):
+ pil_im = Image.open(path).convert('RGB')
+ if center_crop < 1.0:
+ width, height = pil_im.size
+ pil_im = pil_im.crop((
+ int((1 - center_crop) * width / 2), int((1 - center_crop) * height / 2),
+ int((1 + center_crop) * width / 2), int((1 + center_crop) * height / 2),
+ ))
+ input_img = pil_im.resize((256, 256))
+ input_img = np.array(input_img) / 255.0
+ input_img = input_img.astype(np.float32)
+ return input_img
+
+
+def match_mulitple_path(root_dir, regex):
+ videos = []
+ for root, _, files in os.walk(root_dir):
+ for file in files:
+ videos.append(os.path.join(root, file))
+
+ videos = [v for v in videos if not v.split('/')[-1].startswith('.')]
+
+ grouped_path = {}
+ for r in regex:
+ r = re.compile(r)
+ for v in videos:
+ matched = r.findall(v)
+ if len(matched) > 0:
+ groups = matched[0]
+ if groups not in grouped_path:
+ grouped_path[groups] = []
+ grouped_path[groups].append(v)
+
+ grouped_path = {
+ k: tuple(v) for k, v in grouped_path.items()
+ if len(v) == len(regex)
+ }
+ return list(grouped_path.values())
+
+
+def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True):
+ assert length >= n_frames
+ max_stride = min(
+ (length - 1) // (n_frames - 1),
+ max_stride
+ )
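+ # e.g. length=100, n_frames=16 caps the stride at (100 - 1) // (16 - 1) = 6,
+ # so the sampled clip spans at most 15 * 6 + 1 = 91 of the available frames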
+ stride = np.random.randint(1, max_stride + 1)
+ if random_start:
+ start = np.random.randint(0, length - (n_frames - 1) * stride)
+ else:
+ start = 0
+ return np.arange(n_frames) * stride + start
+
+
+def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0):
+ files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)]
+ files = natsorted([x for x in files if is_image(x)])
+
+ total_frames = len(files)
+
+ if total_frames < n_frames:
+ return None
+
+ max_stride = (total_frames - 1) // (n_frames - 1)
+ stride = min(max_stride, stride)
+
+ if random_start:
+ start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
+ else:
+ start = 0
+ frame_indices = np.arange(n_frames) * stride + start
+
+ frames = []
+ for frame_index in sorted(frame_indices):
+ # Check if the frame_index is valid
+ frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop))
+ if len(frames) < n_frames:
+ return None
+ frames = np.stack(frames, axis=0)
+ return frames
+
+
+def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0):
+
+ frames = []
+ cap = cv2.VideoCapture(video_path)
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ if total_frames < n_frames:
+ cap.release()
+ return None
+
+ max_stride = (total_frames - 1) // (n_frames - 1)
+ stride = min(max_stride, stride)
+
+ if random_start:
+ start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
+ else:
+ start = 0
+ frame_indices = np.arange(n_frames) * stride + start
+
+ for frame_index in sorted(frame_indices):
+ # Check if the frame_index is valid
+ if 0 <= frame_index < total_frames:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+ ret, frame = cap.read()
+ if ret:
+ if center_crop < 1.0:
+ height, width, _ = frame.shape
+ frame = frame[
+ int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
+ int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
+ :
+ ]
+ frame = cv2.resize(frame, (256, 256))
+
+ frames.append(frame)
+
+ else:
+ print(f"Frame index {frame_index} is out of bounds. Skipping...")
+
+ cap.release()
+ if len(frames) < n_frames:
+ return None
+ frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+
+ # From BGR to RGB
+ return np.stack(
+ [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+ )
+
+
+def read_all_frames_from_video(video_path, center_crop=1.0):
+
+ frames = []
+ cap = cv2.VideoCapture(video_path)
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+
+ for frame_index in range(total_frames):
+ # Check if the frame_index is valid
+ if 0 <= frame_index < total_frames:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+ ret, frame = cap.read()
+ if ret:
+ if center_crop < 1.0:
+ height, width, _ = frame.shape
+ frame = frame[
+ int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
+ int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
+ :
+ ]
+ frames.append(cv2.resize(frame, (256, 256)))
+ else:
+ print(f"Frame index {frame_index} is out of bounds. Skipping...")
+
+ cap.release()
+ if len(frames) == 0:
+ return None
+ frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+ # From BGR to RGB
+ return np.stack(
+ [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+ )
+
+
+def read_max_span_frames_from_video(video_path, n_frames):
+ frames = []
+ cap = cv2.VideoCapture(video_path)
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if total_frames < n_frames:
+ cap.release()
+ return None
+ stride = (total_frames - 1) // (n_frames - 1)
+ frame_indices = np.arange(n_frames) * stride
+
+ frames = []
+ for frame_index in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+ ret, frame = cap.read()
+ if ret:
+ frames.append(cv2.resize(frame, (256, 256)))
+
+ cap.release()
+ if len(frames) < n_frames:
+ return None
+
+ frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+ # From BGR to RGB
+ return np.stack(
+ [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+ )
+
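+# Illustrative usage sketch (added; 'clip.mp4' is a placeholder path):
+#
+#   frames = read_max_span_frames_from_video('clip.mp4', n_frames=16)
+#   if frames is not None:        # None means fewer than 16 readable frames
+#       print(frames.shape)       # (16, 256, 256, 3), RGB floats in [0, 1]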
diff --git a/vqvae/.DS_Store b/vqvae/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..38734ca2de71d90578b12a191d5ff30a57f26d5c
Binary files /dev/null and b/vqvae/.DS_Store differ
diff --git a/vqvae/__init__.py b/vqvae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b32aee9a68b1192ae0be7214ca92f35defd717
--- /dev/null
+++ b/vqvae/__init__.py
@@ -0,0 +1,25 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.0.1"
+
+# from .modeling_ema import EMAModel
+# from .modeling_maskgit_vqgan import MaskGitVQGAN
+# from .modeling_movq import MOVQ
+# from .modeling_paella_vq import PaellaVQModel
+# from .modeling_utils import VQGANModel
+# from .modeling_transformer import MaskGitTransformer, MaskGiTUViT
+# from .pipeline_muse import PipelineMuse, PipelineMuseInpainting
+# from .sampling import get_mask_chedule
diff --git a/vqvae/logging.py b/vqvae/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..65814a82380e47e54434c4be97026141772f7298
--- /dev/null
+++ b/vqvae/logging.py
@@ -0,0 +1,338 @@
+# coding=utf-8
+# Copyright 2023 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL # NOQA
+from logging import DEBUG # NOQA
+from logging import ERROR # NOQA
+from logging import FATAL # NOQA
+from logging import INFO # NOQA
+from logging import NOTSET # NOQA
+from logging import WARN # NOQA
+from logging import WARNING # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+ not - fall back to `_default_log_level`
+ """
+ env_level_str = os.getenv("muse_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom muse module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 muse's root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+
+
+ 🤗 muse has the following logging levels:
+
+ - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
+ - 40: `muse.logging.ERROR`
+ - 30: `muse.logging.WARNING` or `muse.logging.WARN`
+ - 20: `muse.logging.INFO`
+ - 10: `muse.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 muse's root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `muse.logging.CRITICAL` or `muse.logging.FATAL`
+ - `muse.logging.ERROR`
+ - `muse.logging.WARNING` or `muse.logging.WARN`
+ - `muse.logging.INFO`
+ - `muse.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace muse's root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace muse's root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """Adds a handler to the HuggingFace muse's root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """Removes the given handler from the HuggingFace muse's root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler not in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace muse's default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace muse logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace muse's loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
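+# Illustrative usage sketch (added): raise the library's log verbosity and emit a
+# message through its root logger; all names below are defined in this module.
+#
+#   from vqvae import logging as vq_logging
+#
+#   vq_logging.set_verbosity_info()
+#   logger = vq_logging.get_logger(__name__)
+#   logger.info("tokenizer weights loaded")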
diff --git a/vqvae/modeling_utils.py b/vqvae/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdbdb68a7c2f154c670fe40950c035fad06e4691
--- /dev/null
+++ b/vqvae/modeling_utils.py
@@ -0,0 +1,1171 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import inspect
+import json
+import os
+from collections import OrderedDict
+from functools import partial
+from pathlib import PosixPath
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import accelerate
+import numpy as np
+import torch
+from accelerate.utils import set_module_tensor_to_device
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+)
+from requests import HTTPError
+from torch import Tensor, device
+
+from . import __version__, logging
+
+logger = logging.get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "muse")
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "pytorch_model.bin"
+SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+MUSE_CACHE = default_cache_path
+MUSE_DYNAMIC_MODULE_NAME = "muse_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+
+_LOW_CPU_MEM_USAGE_DEFAULT = True
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
+ return torch.load(checkpoint_file, map_location="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+                f"Unable to load weights from the checkpoint file '{checkpoint_file}'. "
+                "Make sure it is a valid PyTorch checkpoint saved with `torch.save`."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+def _get_model_file(
+ pretrained_model_name_or_path,
+ *,
+ weights_name,
+ subfolder,
+ cache_dir,
+ force_download,
+ proxies,
+ resume_download,
+ local_files_only,
+ use_auth_token,
+ user_agent,
+ revision,
+):
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
+ return pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+ return model_file
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ return model_file
+ else:
+ raise EnvironmentError(
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=weights_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+ return model_file
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a file named {weights_name}."
+                "\nCheck out your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {weights_name}"
+ )
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~models.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_version", "_class_name", "_name_or_path"]
+ _supports_gradient_checkpointing = False
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def is_gradient_checkpointing(self) -> bool:
+ """
+ Whether gradient checkpointing is activated for this model or not.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+ def enable_gradient_checkpointing(self):
+ """
+ Activates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if not self._supports_gradient_checkpointing:
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+ def disable_gradient_checkpointing(self):
+ """
+ Deactivates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if self._supports_gradient_checkpointing:
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
+
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import UNet2DConditionModel
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ >>> model = UNet2DConditionModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+ ... )
+ >>> model = model.to("cuda")
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~models.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training (e.g.
+                on TPUs), where this function may be called on every process. In that case, set `is_main_process=True`
+                only on the main process to avoid race conditions.
+ save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs)
+                when `torch.save` needs to be replaced by another method. Defaults to `torch.save` when not provided.
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
+ The state dictionary to save. If `None`, the model's state dictionary will be saved.
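+
+        Example (illustrative sketch; `VQGANModel` stands for a `ModelMixin` subclass and the path is a placeholder):
+
+        ```py
+        >>> model.save_pretrained("./my_model_directory")
+        >>> reloaded = VQGANModel.from_pretrained("./my_model_directory")
+        ```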
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ if state_dict is None:
+ state_dict = model_to_save.state_dict()
+
+ weights_name = WEIGHTS_NAME
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weights_name))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be specified down to each
+                parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
+                same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
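+        Example (illustrative sketch; `VQGANModel` is used as a stand-in for a concrete `ModelMixin` subclass, and the
+        local path is a placeholder):
+
+        ```py
+        >>> import torch
+        >>> from vqvae_muse import VQGANModel
+        >>> model = VQGANModel.from_pretrained("./my_model_directory", torch_dtype=torch.float16)
+        ```
+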
+ """
+ cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False) # TODO
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+        # Load model
+
+ model_file = None
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+
+ if low_cpu_mem_usage:
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
+ if device_map is None:
+ param_device = "cpu"
+ state_dict = load_state_dict(model_file)
+ # move the params from meta device to cpu
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if len(missing_keys) > 0:
+ raise ValueError(
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+ " those weights or else make sure your checkpoint file is correct."
+ )
+
+ for param_name, param in state_dict.items():
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+ if accepts_dtype:
+ set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
+ else:
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
+ else: # else let accelerate handle loading and dispatching.
+ # Load weights and dispatch according to the device_map
+                # by default the device_map is None and the weights are loaded on the CPU
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
+
+ loading_info = {
+ "missing_keys": [],
+ "unexpected_keys": [],
+ "mismatched_keys": [],
+ "error_msgs": [],
+ }
+ else:
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ state_dict = load_state_dict(model_file)
+
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = [k for k in state_dict.keys()]
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
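+
+        Example (illustrative):
+
+        ```py
+        >>> total = model.num_parameters()
+        >>> trainable = model.num_parameters(only_trainable=True)
+        ```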
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+""" ConfigMixin base class and utilities."""
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+class ConfigMixin:
+ r"""
+    Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles
+    all methods for loading/downloading/saving classes that inherit from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+ subclass).
+ """
+ config_name = None
+ ignore_for_config = []
+ has_compatibles = False
+
+ _deprecated_kwargs = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, **kwargs):
+ r"""
+ Instantiate a Python class from a config dictionary
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
+ configuration files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it is loaded) and initialize the Python class.
+ `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+ overwrite same named arguments of `config`.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
+ # ======>
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**config)
+ return model
+
+ @classmethod
+ def load_config(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Instantiate a Python class from a config dictionary
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
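+        Example (illustrative sketch; `VQGANModel` is used as a stand-in for a `ConfigMixin` subclass and the path is
+        a placeholder):
+
+        ```py
+        >>> config = VQGANModel.load_config("./my_model_directory")
+        >>> isinstance(config, dict)
+        True
+        ```
+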
+ """
+ cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ _ = kwargs.pop("mirror", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                        f" directory containing a {cls.config_name} file.\nCheck out your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if return_unused_kwargs:
+ return config_dict, kwargs
+
+ return config_dict
+
+ @staticmethod
+ def _get_init_keys(cls):
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ config_dict["_class_name"] = self.__class__.__name__
+ config_dict["_version"] = __version__
+
+ def to_json_saveable(value):
+ if isinstance(value, np.ndarray):
+ value = value.tolist()
+ elif isinstance(value, PosixPath):
+ value = str(value)
+ return value
+
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
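+
+    Example (illustrative sketch; `MyAutoencoder` is a hypothetical subclass, not part of this file):
+
+    ```py
+    >>> class MyAutoencoder(ModelMixin, ConfigMixin):
+    ...     config_name = "config.json"
+    ...
+    ...     @register_to_config
+    ...     def __init__(self, hidden_channels: int = 128, z_channels: int = 256):
+    ...         super().__init__()
+
+    >>> MyAutoencoder().config["hidden_channels"]
+    128
+    ```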
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
+ getattr(self, "register_to_config")(**new_kwargs)
+ init(self, *args, **init_kwargs)
+
+ return inner_init
diff --git a/vqvae_muse.py b/vqvae_muse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2d57f7e54c46b71c75bff0124f02d6df146d7cb
--- /dev/null
+++ b/vqvae_muse.py
@@ -0,0 +1,594 @@
+# coding=utf-8
+# Copyright 2023 The Taming Transformers Authors and The HuggingFace Inc. team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import partial
+from typing import Tuple
+import os
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from vqvae.modeling_utils import ConfigMixin, ModelMixin, register_to_config
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int, with_conv: bool):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int, with_conv: bool):
+ super().__init__()
+
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, hidden_states):
+ if self.with_conv:
+ pad = (0, 1, 0, 1) # pad height and width dim
+ hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
+ return hidden_states
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ use_conv_shortcut: bool = False,
+ dropout_prob: float = 0.0,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
+ self.use_conv_shortcut = use_conv_shortcut
+
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(
+ self.in_channels,
+ self.out_channels_,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
+ self.dropout = nn.Dropout(dropout_prob)
+ self.conv2 = nn.Conv2d(
+ self.out_channels_,
+ self.out_channels_,
+ kernel_size=3,
+ stride=(1, 1),
+ padding=1,
+ )
+
+ if self.in_channels != self.out_channels_:
+ if use_conv_shortcut:
+ self.conv_shortcut = nn.Conv2d(
+ self.in_channels,
+ self.out_channels_,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ else:
+ self.nin_shortcut = nn.Conv2d(
+ self.in_channels,
+ self.out_channels_,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels_:
+ if self.use_conv_shortcut:
+ residual = self.conv_shortcut(residual)
+ else:
+ residual = self.nin_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+
+ self.in_channels = in_channels
+ conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0)
+
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
+ self.q, self.k, self.v = conv(), conv(), conv()
+ self.proj_out = conv()
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+
+ query = self.q(hidden_states)
+ key = self.k(hidden_states)
+ value = self.v(hidden_states)
+
+ # compute attentions
+ batch, channels, height, width = query.shape
+ query = query.reshape((batch, channels, height * width))
+ query = query.permute(0, 2, 1) # (b, hw, c)
+ key = key.reshape((batch, channels, height * width))
+
+ attn_weights = torch.bmm(query, key) # b,hw,hw
+ attn_weights = attn_weights * (int(channels) ** -0.5)
+ attn_weights = nn.functional.softmax(attn_weights, dim=2)
+
+ # attend to values
+ value = value.reshape((batch, channels, height * width))
+ attn_weights = attn_weights.permute(0, 2, 1)
+ hidden_states = torch.bmm(value, attn_weights)
+ hidden_states = hidden_states.reshape((batch, channels, height, width))
+
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+class UpsamplingBlock(nn.Module):
+ def __init__(self, config, curr_res: int, block_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.block_idx = block_idx
+ self.curr_res = curr_res
+
+ if self.block_idx == self.config.num_resolutions - 1:
+ block_in = self.config.hidden_channels * self.config.channel_mult[-1]
+ else:
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
+
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+ res_blocks = []
+ attn_blocks = []
+ for _ in range(self.config.num_res_blocks + 1):
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+ block_in = block_out
+ if self.curr_res in self.config.attn_resolutions:
+ attn_blocks.append(AttnBlock(block_in))
+
+ self.block = nn.ModuleList(res_blocks)
+ self.attn = nn.ModuleList(attn_blocks)
+
+ self.upsample = None
+ if self.block_idx != 0:
+ self.upsample = Upsample(block_in, self.config.resample_with_conv)
+
+ def forward(self, hidden_states):
+ for i, res_block in enumerate(self.block):
+ hidden_states = res_block(hidden_states)
+ if len(self.attn) > 1:
+ hidden_states = self.attn[i](hidden_states)
+
+ if self.upsample is not None:
+ hidden_states = self.upsample(hidden_states)
+
+ return hidden_states
+
+
+class DownsamplingBlock(nn.Module):
+ def __init__(self, config, curr_res: int, block_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.curr_res = curr_res
+ self.block_idx = block_idx
+
+ in_channel_mult = (1,) + tuple(self.config.channel_mult)
+ block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+ res_blocks = nn.ModuleList()
+ attn_blocks = nn.ModuleList()
+ for _ in range(self.config.num_res_blocks):
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+ block_in = block_out
+ if self.curr_res in self.config.attn_resolutions:
+ attn_blocks.append(AttnBlock(block_in))
+
+ self.block = res_blocks
+ self.attn = attn_blocks
+
+ self.downsample = None
+ if self.block_idx != self.config.num_resolutions - 1:
+ self.downsample = Downsample(block_in, self.config.resample_with_conv)
+
+ def forward(self, hidden_states):
+ for i, res_block in enumerate(self.block):
+ hidden_states = res_block(hidden_states)
+ if len(self.attn) > 1:
+ hidden_states = self.attn[i](hidden_states)
+
+ if self.downsample is not None:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class MidBlock(nn.Module):
+    def __init__(self, config, in_channels: int, no_attn: bool, dropout: float):
+ super().__init__()
+
+ self.config = config
+ self.in_channels = in_channels
+ self.no_attn = no_attn
+ self.dropout = dropout
+
+ self.block_1 = ResnetBlock(
+ self.in_channels,
+ self.in_channels,
+ dropout_prob=self.dropout,
+ )
+ if not no_attn:
+ self.attn_1 = AttnBlock(self.in_channels)
+ self.block_2 = ResnetBlock(
+ self.in_channels,
+ self.in_channels,
+ dropout_prob=self.dropout,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.block_1(hidden_states)
+ if not self.no_attn:
+ hidden_states = self.attn_1(hidden_states)
+ hidden_states = self.block_2(hidden_states)
+ return hidden_states
+
+
+class Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+
+ # downsampling
+ self.conv_in = nn.Conv2d(
+ self.config.num_channels,
+ self.config.hidden_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ curr_res = self.config.resolution
+ downsample_blocks = []
+ for i_level in range(self.config.num_resolutions):
+ downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level))
+
+ if i_level != self.config.num_resolutions - 1:
+ curr_res = curr_res // 2
+ self.down = nn.ModuleList(downsample_blocks)
+
+ # middle
+ mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
+ self.mid = MidBlock(config, mid_channels, self.config.no_attn_mid_block, self.config.dropout)
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ self.config.z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, pixel_values):
+ # downsampling
+ hidden_states = self.conv_in(pixel_values)
+ for block in self.down:
+ hidden_states = block(hidden_states)
+
+ # middle
+ hidden_states = self.mid(hidden_states)
+
+ # end
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class Decoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+
+ # compute in_channel_mult, block_in and curr_res at lowest res
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
+ curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
+ self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(
+ self.config.z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ # middle
+ self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout)
+
+ # upsampling
+ upsample_blocks = []
+ for i_level in reversed(range(self.config.num_resolutions)):
+ upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level))
+ if i_level != 0:
+ curr_res = curr_res * 2
+ self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order
+
+ # end
+ block_out = self.config.hidden_channels * self.config.channel_mult[0]
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(
+ block_out,
+ self.config.num_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ # z to block_in
+ hidden_states = self.conv_in(hidden_states)
+
+ # middle
+ hidden_states = self.mid(hidden_states)
+
+ # upsampling
+ for block in reversed(self.up):
+ hidden_states = block(hidden_states)
+
+ # end
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class VectorQuantizer(nn.Module):
+ """
+    Discretization bottleneck part of the VQ-VAE.
+
+    See https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+ r"""
+ Args:
+ num_embeddings: number of vectors in the quantized space.
+ embedding_dim: dimensionality of the tensors in the quantized space.
+ Inputs to the modules must be in this format as well.
+ commitment_cost: scalar which controls the weighting of the loss terms
+ (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
+ """
+ super().__init__()
+
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.commitment_cost = commitment_cost
+
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+ self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
+
+ def forward(self, hidden_states, return_loss=False):
+ """
+ Takes the encoder output z and maps each spatial vector to the index of its
+ closest codebook embedding e_j, i.e. z (continuous) -> z_q (discrete), with
+ z.shape = (batch, channel, height, width).
+
+ Quantization pipeline:
+ 1. reshape the encoder output (B, C, H, W) to (B, H, W, C)
+ 2. flatten it to (B*H*W, C) and pick the nearest codebook entry for each row
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+
+ distances = self.compute_distances(hidden_states)
+ min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
+
+ # reshape to (batch, num_tokens)
+ min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
+
+ # compute loss for embedding
+ loss = None
+ if return_loss:
+ loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
+ (z_q - hidden_states.detach()) ** 2
+ )
+ # preserve gradients
+ z_q = hidden_states + (z_q - hidden_states).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, min_encoding_indices, loss
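+
+ # Usage sketch for forward() (illustrative, not part of the original code): given
+ # an encoder output z of shape (B, D, H, W),
+ # z_q, indices, loss = quantizer(z, return_loss=True)
+ # returns z_q with the same (B, D, H, W) shape, indices of shape (B, H * W), and
+ # the codebook + commitment loss; with return_loss=True the straight-through
+ # trick above also lets gradients flow from z_q back to z.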
+
+ def compute_distances(self, hidden_states):
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ hidden_states_flattened = hidden_states.reshape((-1, self.embedding_dim))
+ emb_weights = self.embedding.weight.t()
+
+ inputs_norm_sq = hidden_states_flattened.pow(2.0).sum(dim=1, keepdim=True)
+ codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
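+ # addmm(M, A, B, alpha=-2.0) computes M - 2 * (A @ B), i.e. the expanded
+ # squared distance ||z||^2 + ||e||^2 - 2 * z . e for every (z, e) pair.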
+ distances = torch.addmm(
+ inputs_norm_sq + codebook_t_norm_sq,
+ hidden_states_flattened,
+ emb_weights,
+ alpha=-2.0,
+ )
+ return distances
+
+ def get_codebook_entry(self, indices):
+ # indices are expected to be of shape (batch, num_tokens)
+ # get quantized latent vectors
+ batch, num_tokens = indices.shape
+ z_q = self.embedding(indices)
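+ # num_tokens is assumed to be a perfect square (e.g. 256 -> 16 x 16), matching
+ # the square latent grid produced by the encoder.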
+ z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
+ return z_q
+
+ # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
+ def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
+ distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
+
+ soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
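+ # stochastic=True samples the code from the softened distribution; otherwise the
+ # nearest codebook entry is taken and temp only affects the returned soft_code.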
+ if stochastic:
+ code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
+ else:
+ code = distances.argmin(dim=-1) # (batch * height * width)
+
+ code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
+ batch, num_tokens = code.shape
+ soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
+ return soft_code, code
+
+ def get_code(self, hidden_states):
+ # reshape z -> (batch, height, width, channel)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+ distances = self.compute_distances(hidden_states)
+ indices = torch.argmin(distances, dim=1).unsqueeze(1)
+ indices = indices.reshape(hidden_states.shape[0], -1)
+ return indices
+
+
+class VQGANModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ resolution: int = 256,
+ num_channels: int = 3,
+ hidden_channels: int = 128,
+ channel_mult: Tuple = (1, 1, 2, 2, 4),
+ num_res_blocks: int = 2,
+ attn_resolutions: Tuple = (16,),
+ no_attn_mid_block: bool = False,
+ z_channels: int = 256,
+ num_embeddings: int = 1024,
+ quantized_embed_dim: int = 256,
+ dropout: float = 0.0,
+ resample_with_conv: bool = True,
+ commitment_cost: float = 0.25,
+ ):
+ super().__init__()
+
+ self.config.num_resolutions = len(channel_mult)
+ self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1)
+ self.config.latent_size = resolution // self.config.reduction_factor
+
+ self.encoder = Encoder(self.config)
+ self.decoder = Decoder(self.config)
+ self.quantize = VectorQuantizer(
+ self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost
+ )
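+ # 1x1 convolutions project between the encoder/decoder latent channels
+ # (z_channels) and the codebook embedding dimension (quantized_embed_dim).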
+ self.quant_conv = nn.Conv2d(
+ self.config.z_channels,
+ self.config.quantized_embed_dim,
+ kernel_size=1,
+ )
+ self.post_quant_conv = nn.Conv2d(
+ self.config.quantized_embed_dim,
+ self.config.z_channels,
+ kernel_size=1,
+ )
+
+ def encode(self, pixel_values, return_loss=False):
+ hidden_states = self.encoder(pixel_values)
+ hidden_states = self.quant_conv(hidden_states)
+ quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
+ output = (quantized_states, codebook_indices)
+ if return_loss:
+ output = output + (codebook_loss,)
+ return output
+
+ def decode(self, quantized_states):
+ hidden_states = self.post_quant_conv(quantized_states)
+ reconstructed_pixel_values = self.decoder(hidden_states)
+ return reconstructed_pixel_values
+
+ def decode_code(self, codebook_indices):
+ quantized_states = self.quantize.get_codebook_entry(codebook_indices)
+ reconstructed_pixel_values = self.decode(quantized_states)
+ return reconstructed_pixel_values
+
+ def get_code(self, pixel_values):
+ hidden_states = self.encoder(pixel_values)
+ hidden_states = self.quant_conv(hidden_states)
+ codebook_indices = self.quantize.get_code(hidden_states)
+ return codebook_indices
+
+ def forward(self, pixel_values, return_loss=False):
+ hidden_states = self.encoder(pixel_values)
+ hidden_states = self.quant_conv(hidden_states)
+ quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
+ reconstructed_pixel_values = self.decode(quantized_states)
+ outputs = (reconstructed_pixel_values, quantized_states, codebook_indices)
+ if return_loss:
+ outputs = outputs + (codebook_loss,)
+ return outputs
+
+
+def get_tokenizer_muse():
+ """Load the pretrained VQGAN tokenizer weights from Emma02/vqvae_ckpts."""
+ ckpts_path = "Emma02/vqvae_ckpts"
+ net = VQGANModel.from_pretrained(ckpts_path)
+
+ return net
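+
+
+if __name__ == "__main__":
+ # Illustrative smoke test (a sketch, not part of the original module): tokenize a
+ # random image batch into discrete codes and reconstruct it, assuming the
+ # checkpoint follows the default 256x256 configuration registered above.
+ tokenizer = get_tokenizer_muse()
+ pixels = torch.rand(1, 3, 256, 256) # fake RGB batch with values in [0, 1)
+ tokens = tokenizer.get_code(pixels) # (1, 256) codebook indices
+ recon = tokenizer.decode_code(tokens) # (1, 3, 256, 256) reconstruction
+ print(tokens.shape, recon.shape)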