diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index a8cc3865dcc72ed09a6108308354af98f9eeffeb..04ac80b975ace53c2738b1053f5d4d9cf54d303f 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,11 @@ --- -title: LVM -emoji: 🔥 -colorFrom: yellow -colorTo: gray +title: VQLM Demo +emoji: 🎨 +colorFrom: "yellow" +colorTo: "blue" sdk: gradio -sdk_version: 4.36.1 +sdk_version: "4.29.0" app_file: app.py pinned: false -license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7d34083aa00b5137708967eba4245149d4545eb6 --- /dev/null +++ b/app.py @@ -0,0 +1,244 @@ +import gradio as gr +import numpy as np +import mlxu +import os +import re +import torch + +from io import BytesIO +from natsort import natsorted +from PIL import Image + +from inference import LocalInferenceModel + +FLAGS, _ = mlxu.define_flags_with_default( + host='0.0.0.0', + port=5000, + dtype='float16', + checkpoint='Emma02/LVM_ckpts', + torch_devices='', + context_frames=16, +) + +def natural_sort_key(s): + return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] + +def load_example_image_groups(directory): + example_groups = {} + for subdir in os.listdir(directory): + subdir_path = os.path.join(directory, subdir) + if os.path.isdir(subdir_path): + example_groups[subdir] = [] + images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + images = natsorted(images, key=natural_sort_key) + for filename in images: + img = Image.open(os.path.join(subdir_path, filename)) + example_groups[subdir].append(img) + return example_groups + +def main(_): + assert FLAGS.checkpoint != '' + + model = LocalInferenceModel( + checkpoint=FLAGS.checkpoint, + torch_device=torch.device("cuda"), + dtype=FLAGS.dtype, + context_frames=FLAGS.context_frames, + use_lock=False, + ) + + checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1) + checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1) + checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32) + + def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9): + assert len(input_images) > 0 + input_images = [ + np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0 + for img in input_images + ] + input_images = np.stack(input_images, axis=0) + output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0] + + generated_images = [] + for candidate in output_images: + concatenated_image = [] + for i, img in enumerate(candidate): + concatenated_image.append(img) + if i < len(candidate) - 1: + concatenated_image.append(checkerboard) + generated_images.append( + Image.fromarray( + (np.concatenate(concatenated_image, axis=1) * 
255).astype(np.uint8) + ) + ) + + return generated_images + + with gr.Blocks(css=""" + .small-button { + padding: 5px 10px; + min-width: 80px; + } + .large-gallery img { + width: 100%; + height: auto; + max-height: 150px; + } + """) as demo: + with gr.Column(): + image_list = gr.State([]) + gr.Markdown('# VQLM Demo') + gr.Markdown(f'Serving model: {FLAGS.checkpoint}') + gr.Markdown('## Inputs') + with gr.Row(): + upload_drag = gr.File( + type='binary', + file_types=['image'], + file_count='multiple', + ) + with gr.Column(): + gen_length_slider = gr.Slider( + label='Generation length', + minimum=1, + maximum=32, + value=1, + step=1, + interactive=True, + ) + n_candidates_slider = gr.Slider( + label='Number of candidates', + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + ) + temp_slider = gr.Slider( + label='Temperature', + minimum=0, + maximum=2.0, + value=1.0, + interactive=True, + ) + top_p_slider = gr.Slider( + label='Top p', + minimum=0, + maximum=1.0, + value=0.9, + interactive=True, + ) + clear_btn = gr.Button( + value='Clear', + elem_classes=['small-button'], + ) + generate_btn = gr.Button( + value='Generate', + interactive=False, + elem_classes=['small-button'], + ) + input_gallery = gr.Gallery( + columns=7, + rows=1, + object_fit='scale-down', + label="Input image sequence" + ) + gr.Markdown('## Outputs') + output_gallery = gr.Gallery( + columns=4, + object_fit='scale-down', + label="Output image" + ) + + def upload_image_fn(files, images): + for file in files: + images.append(Image.open(BytesIO(file))) + + return { + upload_drag: None, + image_list: images, + input_gallery: images, + generate_btn: gr.update(interactive=True), + } + + def clear_fn(): + return { + image_list: [], + input_gallery: [], + generate_btn: gr.update(interactive=False), + output_gallery: [], + } + + def disable_generate_btn(): + return { + generate_btn: gr.update(interactive=False), + } + + def generate_fn(images, n_candidates, gen_length, temperature, top_p): + new_images = generate_images( + images, + gen_length, + n_candidates=n_candidates, + temperature=temperature, + top_p=top_p, + ) + return { + output_gallery: new_images, + generate_btn: gr.update(interactive=True), + } + + upload_drag.upload( + upload_image_fn, + inputs=[upload_drag, image_list], + outputs=[upload_drag, image_list, input_gallery, generate_btn], + ) + clear_btn.click( + clear_fn, + inputs=None, + outputs=[image_list, input_gallery, generate_btn, output_gallery], + ) + generate_btn.click( + disable_generate_btn, + inputs=None, + outputs=[generate_btn], + ).then( + generate_fn, + inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider], + outputs=[output_gallery, generate_btn], + ) + + example_groups = load_example_image_groups('prompts') + + def add_image_group_fn(group_name, images): + new_images = images + example_groups[group_name] + return { + image_list: new_images, + input_gallery: new_images, + generate_btn: gr.update(interactive=True), + } + + for group_name, group_images in example_groups.items(): + with gr.Row(): + with gr.Column(scale=3): + add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button']) + with gr.Column(scale=7): + group_gallery = gr.Gallery( + value=[Image.fromarray(np.array(img)) for img in group_images], + columns=5, + rows=1, + object_fit='scale-down', + label=group_name, + elem_classes=['large-gallery'], + ) + + add_button.click( + add_image_group_fn, + inputs=[gr.State(group_name), image_list], + outputs=[image_list, input_gallery, 
generate_btn],
+            )
+
+    demo.launch()
+
+if __name__ == "__main__":
+    mlxu.run(main)
+
diff --git a/batch_generation.py b/batch_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c044af5d192ad050be4199463085cb8df214dc8
--- /dev/null
+++ b/batch_generation.py
@@ -0,0 +1,223 @@
+"""
+Batch generation for sequences of images. This script accepts a jsonl file
+as input, where each line is a JSON dictionary representing one example in
+the evaluation set. The dictionary should have two keys:
+
+    input: a list of paths to the input images given to the model as context.
+    output: a string giving the path where the generated output should be saved.
+
+The script runs the model to generate the output images, concatenates the
+input and output images together, and saves the result to the output path.
+"""
+
+import os
+import json
+from PIL import Image
+import numpy as np
+import mlxu
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+import einops
+import torch
+
+from .inference import MultiProcessInferenceModel
+from .utils import read_image_to_tensor, MultiProcessImageSaver
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    input_file='',
+    checkpoint='',
+    input_base_dir='',
+    output_base_dir='',
+    evaluate_mse=False,
+    json_input_key='input',
+    json_output_key='output',
+    json_target_key='target',
+    n_new_frames=1,
+    n_candidates=2,
+    context_frames=16,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    dtype='float16',
+    torch_devices='',
+    batch_size_factor=4,
+    max_examples=0,
+    resize_output='',
+    include_input=False,
+)
+
+# Dataset over the records of the input jsonl file.
+class MultiFrameDataset(torch.utils.data.Dataset):
+    def __init__(self, input_files, output_files, target_files=None):
+        assert len(input_files)
+        self.input_files = input_files
+        self.output_files = output_files
+        self.target_files = target_files
+
+    def __len__(self):
+        return len(self.input_files)
+
+    def __getitem__(self, idx):
+        original_size = Image.open(self.input_files[idx][-1]).size
+        input_images = np.stack(
+            [read_image_to_tensor(f) for f in self.input_files[idx]],
+            axis=0
+        )
+
+        if self.target_files is not None:
+            target_images = np.stack(
+                [read_image_to_tensor(f) for f in self.target_files[idx]],
+                axis=0
+            )
+        else:
+            target_images = None
+        return input_images, target_images, self.output_files[idx], np.array(original_size)
+
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+
+    print(f'Loading checkpoint from {FLAGS.checkpoint}')
+    print(f'Evaluating input file from {FLAGS.input_file}')
+
+    # Build the inference model.
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_devices=FLAGS.torch_devices,
+        dtype=FLAGS.dtype,
+        context_frames=FLAGS.context_frames,
+        use_lock=True,
+    )
+
+    # input_files / output_files are read from the input jsonl file (produced by a separate script).
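+    # Illustrative sketch of one line of the input jsonl file (the file names
+    # here are made up; the 'target' key is only needed when --evaluate_mse is set):
+    #   {"input": ["frames/000.png", "frames/001.png"],
+    #    "output": "out/result.png",
+    #    "target": ["frames/002.png"]}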
+ input_files = [] + output_files = [] + + if FLAGS.evaluate_mse: + target_files = [] + else: + target_files = None + + with mlxu.open_file(FLAGS.input_file, 'r') as f: + for line in f: + record = json.loads(line) + input_files.append(record[FLAGS.json_input_key]) + output_files.append(record[FLAGS.json_output_key]) + if FLAGS.evaluate_mse: + target_files.append(record[FLAGS.json_target_key]) + + + if FLAGS.max_examples > 0: + input_files = input_files[:FLAGS.max_examples] + output_files = output_files[:FLAGS.max_examples] + if FLAGS.evaluate_mse: + target_files = target_files[:FLAGS.max_examples] + + if FLAGS.input_base_dir != '': + input_files = [ + [os.path.join(FLAGS.input_base_dir, x) for x in y] + for y in input_files + ] + if FLAGS.evaluate_mse: + target_files = [ + [os.path.join(FLAGS.input_base_dir, x) for x in y] + for y in target_files + ] + + if FLAGS.output_base_dir != '': + os.makedirs(FLAGS.output_base_dir, exist_ok=True) + output_files = [ + os.path.join(FLAGS.output_base_dir, x) + for x in output_files + ] + + dataset = MultiFrameDataset(input_files, output_files, target_files) + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=FLAGS.batch_size_factor * model.n_processes, + shuffle=False, + num_workers=FLAGS.n_workers, + ) + + image_saver = MultiProcessImageSaver(FLAGS.n_workers) + + mses = [] + + for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0): + + # batch_images is input. + batch_images = batch_images.numpy() + + # + context_length = batch_images.shape[1] + + + generated_images = model( + batch_images, + FLAGS.n_new_frames, + FLAGS.n_candidates, + temperature=FLAGS.temperature, + top_p=FLAGS.top_p + ) + + + repeated_batch = einops.repeat( + batch_images, + 'b s h w c -> b n s h w c', + n=FLAGS.n_candidates, + ) + generated_images = np.array(generated_images) + + if FLAGS.evaluate_mse: + batch_targets = einops.repeat( + batch_targets.numpy(), + 'b s h w c -> b n s h w c', # batch, candidate, s + n=FLAGS.n_candidates, + ) + channels = batch_targets.shape[-1] + # calculate mse loss. 
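+            # per-example mean over the (candidate, frame, H, W, C) axes; the
+            # multiplication by the channel count below rescales the per-channel
+            # mean into a per-pixel sum over channels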
+ mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5)) + + mses.append(mse * channels) + + + if FLAGS.include_input: + combined = einops.rearrange( + np.concatenate([repeated_batch, generated_images], axis=2), + 'b n s h w c -> b (n h) (s w) c' + ) + else: + combined = einops.rearrange( + generated_images, + 'b n s h w c -> b (n h) (s w) c' + ) + combined = (combined * 255).astype(np.uint8) + + n_frames = FLAGS.n_new_frames + if FLAGS.include_input: + n_frames += context_length + + if FLAGS.resize_output == '': + resizes = None + + elif FLAGS.resize_output == 'original': + resizes = batch_sizes.numpy() + resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]]) + else: + resize = tuple(int(x) for x in FLAGS.resize_output.split(',')) + resizes = np.array([resize] * len(batch_sizes)) + resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]]) + + image_saver(combined, batch_output_files, resizes) + + if FLAGS.evaluate_mse: + mses = np.concatenate(mses, axis=0) + print(f'MSE: {np.mean(mses)}') + + image_saver.close() + +if __name__ == "__main__": + mlxu.run(main) \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..57060a49b8f8b92ef4243154e841268c5a3fbc28 --- /dev/null +++ b/demo.py @@ -0,0 +1,263 @@ +import re +from natsort import natsorted + +def natural_sort_key(s): + return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] + +def load_example_image_groups(directory): + example_groups = {} + for subdir in os.listdir(directory): + subdir_path = os.path.join(directory, subdir) + if os.path.isdir(subdir_path): + example_groups[subdir] = [] + images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + images = natsorted(images, key=natural_sort_key) # Natural sorting + for filename in images: + img = Image.open(os.path.join(subdir_path, filename)) + example_groups[subdir].append(img) + return example_groups + + +from io import BytesIO +import gradio as gr +import uvicorn +from fastapi import FastAPI +from PIL import Image +import numpy as np +import mlxu +import os +import re +from natsort import natsorted + +from .inference import MultiProcessInferenceModel + +FLAGS, _ = mlxu.define_flags_with_default( + host='0.0.0.0', + port=5007, + dtype='float16', + checkpoint='', + torch_devices='', + context_frames=16, +) + +def natural_sort_key(s): + return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] + +def load_example_image_groups(directory): + example_groups = {} + for subdir in os.listdir(directory): + subdir_path = os.path.join(directory, subdir) + if os.path.isdir(subdir_path): + example_groups[subdir] = [] + images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + images = natsorted(images, key=natural_sort_key) # Natural sorting + for filename in images: + img = Image.open(os.path.join(subdir_path, filename)) + example_groups[subdir].append(img) + return example_groups + +def main(_): + assert FLAGS.checkpoint != '' + + model = MultiProcessInferenceModel( + checkpoint=FLAGS.checkpoint, + torch_devices=FLAGS.torch_devices, + dtype=FLAGS.dtype, + context_frames=FLAGS.context_frames, + use_lock=True, + ) + + checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1) + checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1) + checkerboard = 
np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32) + + def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9): + assert len(input_images) > 0 + input_images = [ + np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0 + for img in input_images + ] + input_images = np.stack(input_images, axis=0) + output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0] + + generated_images = [] + for candidate in output_images: + concatenated_image = [] + for i, img in enumerate(candidate): + concatenated_image.append(img) + if i < len(candidate) - 1: + concatenated_image.append(checkerboard) + generated_images.append( + Image.fromarray( + (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8) + ) + ) + + return generated_images + + with gr.Blocks(css=""" + .small-button { + padding: 5px 10px; + min-width: 80px; + } + .large-gallery img { + width: 100%; + height: auto; + max-height: 150px; + } + """) as demo: + with gr.Column(): + image_list = gr.State([]) + gr.Markdown('# LVM Demo') + gr.Markdown(f'Serving model: {FLAGS.checkpoint}') + gr.Markdown('## Inputs') + with gr.Row(): + upload_drag = gr.File( + type='binary', + file_types=['image'], + file_count='multiple', + ) + with gr.Column(): + gen_length_slider = gr.Slider( + label='Generation length', + minimum=1, + maximum=32, + value=1, + step=1, + interactive=True, + ) + n_candidates_slider = gr.Slider( + label='Number of candidates', + minimum=1, + maximum=10, + value=1, + step=1, + interactive=True, + ) + temp_slider = gr.Slider( + label='Temperature', + minimum=0, + maximum=2.0, + value=1.0, + interactive=True, + ) + top_p_slider = gr.Slider( + label='Top p', + minimum=0, + maximum=1.0, + value=0.9, + interactive=True, + ) + clear_btn = gr.Button( + value='Clear', + elem_classes=['small-button'], + ) + generate_btn = gr.Button( + value='Generate', + interactive=False, + elem_classes=['small-button'], + ) + input_gallery = gr.Gallery( + columns=7, + rows=1, + object_fit='scale-down', + ) + gr.Markdown('## Outputs') + output_gallery = gr.Gallery( + columns=4, + object_fit='scale-down', + ) + + def upload_image_fn(files, images): + for file in files: + images.append(Image.open(BytesIO(file))) + + return { + upload_drag: None, + image_list: images, + input_gallery: images, + generate_btn: gr.update(interactive=True), + } + + def clear_fn(): + return { + image_list: [], + input_gallery: [], + generate_btn: gr.update(interactive=False), + output_gallery: [], + } + + def disable_generate_btn(): + return { + generate_btn: gr.update(interactive=False), + } + + def generate_fn(images, n_candidates, gen_length, temperature, top_p): + new_images = generate_images( + images, + gen_length, + n_candidates=n_candidates, + temperature=temperature, + top_p=top_p, + ) + return { + output_gallery: new_images, + generate_btn: gr.update(interactive=True), + } + + upload_drag.upload( + upload_image_fn, + inputs=[upload_drag, image_list], + outputs=[upload_drag, image_list, input_gallery, generate_btn], + ) + clear_btn.click( + clear_fn, + inputs=None, + outputs=[image_list, input_gallery, generate_btn, output_gallery], + ) + generate_btn.click( + disable_generate_btn, + inputs=None, + outputs=[generate_btn], + ).then( + generate_fn, + inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider], + outputs=[output_gallery, generate_btn], + ) + + example_groups = load_example_image_groups('/home/yutongbai/demo_images') + 
+ def add_image_group_fn(group_name, images): + new_images = images + example_groups[group_name] + return { + image_list: new_images, + input_gallery: new_images, + generate_btn: gr.update(interactive=True), + } + + for group_name, group_images in example_groups.items(): + with gr.Row(): + with gr.Column(scale=3): + add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button']) + with gr.Column(scale=7): + group_gallery = gr.Gallery( + value=[Image.fromarray(np.array(img)) for img in group_images], + columns=5, + rows=1, + object_fit='scale-down', + label=group_name, + elem_classes=['large-gallery'], + ) + + add_button.click( + add_image_group_fn, + inputs=[gr.State(group_name), image_list], + outputs=[image_list, input_gallery, generate_btn], + ) + + app = FastAPI() + app = gr.mount_gradio_app(app, demo, '/') + uvicorn.run(app, host=FLAGS.host, port=FLAGS.port) + +if __name__ == "__main__": + mlxu.run(main) diff --git a/eval_perplexity.py b/eval_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdefb882fc039d294209f0bb5e254ede100f340 --- /dev/null +++ b/eval_perplexity.py @@ -0,0 +1,127 @@ +""" +Evaluating the perplexity on few shot tasks. This script accept a jsonl file +as input. Each line of the jsonl file representing a dictionary. Each line +represents one example in the evaluation set. The dictionary should have two key: + + input: a list of paths to the input images as context to the model. This + list should include the few shot examples. + target: a list of paths to the target images to evaluate perplexity + +Ths script should run the model and compute the average perplexity on the +evaluation set. +""" + +import os +import json +from PIL import Image +import numpy as np +import mlxu +from tqdm import tqdm, trange +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops + +from .inference import MultiProcessInferenceModel + + +FLAGS, _ = mlxu.define_flags_with_default( + input_file='', + checkpoint='', + input_base_dir='', + batch_size=2, + json_input_key='input', + json_target_key='target', + dtype='float16', + torch_devices='', + n_workers=4, + max_examples=0, +) + + +def read_image_to_tensor(path): + pil_im = Image.open(path).convert('RGB') + input_img = pil_im.resize((256, 256)) + input_img = np.array(input_img) / 255.0 + input_img = input_img.astype(np.float32) + return input_img + + +class MultiFrameDataset(torch.utils.data.Dataset): + def __init__(self, input_files, target_files): + assert len(input_files) == len(target_files) + self.input_files = input_files + self.target_files = target_files + + def __len__(self): + return len(self.input_files) + + def __getitem__(self, idx): + input_list = np.stack( + [read_image_to_tensor(f) for f in self.input_files[idx]], + axis=0 + ) + target_list = np.stack( + [read_image_to_tensor(f) for f in self.target_files[idx]], + axis=0 + ) + return input_list, target_list + + +def main(_): + assert FLAGS.checkpoint != '' + + print(f'Loading checkpoint from {FLAGS.checkpoint}') + print(f'Evaluating input file from {FLAGS.input_file}') + + model = MultiProcessInferenceModel( + checkpoint=FLAGS.checkpoint, + torch_devices=FLAGS.torch_devices, + dtype=FLAGS.dtype, + use_lock=True, + perplexity_batch_size=FLAGS.batch_size, + ) + + input_files = [] + target_files = [] + + with mlxu.open_file(FLAGS.input_file, 'r') as f: + for line in f: + record = json.loads(line) + input_files.append(record[FLAGS.json_input_key]) + target_files.append(record[FLAGS.json_target_key]) + 
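+    # Each record parsed above looks roughly like this (made-up file names),
+    # with 'input' holding the few-shot context images and 'target' the images
+    # whose perplexity is measured:
+    #   {"input": ["task/ex1.png", "task/ex2.png", "task/query.png"],
+    #    "target": ["task/answer.png"]}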
+ if FLAGS.input_base_dir != '': + input_files = [ + [os.path.join(FLAGS.input_base_dir, x) for x in y] + for y in input_files + ] + target_files = [ + [os.path.join(FLAGS.input_base_dir, x) for x in y] + for y in target_files + ] + + if FLAGS.max_examples > 0: + input_files = input_files[:FLAGS.max_examples] + target_files = target_files[:FLAGS.max_examples] + + dataset = MultiFrameDataset(input_files, target_files) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=FLAGS.batch_size * model.n_processes, + shuffle=False, + num_workers=FLAGS.n_workers + ) + + perplexities = [] + + for input_images, target_images in tqdm(data_loader, ncols=0): + perplexity = model.compute_perplexity(input_images, target_images) + perplexities.append(perplexity) + + perplexities = np.concatenate(perplexities, axis=0) + print(f'Perplexity: {np.mean(perplexities)}') + + +if __name__ == "__main__": + mlxu.run(main) \ No newline at end of file diff --git a/eval_video_perplexity.py b/eval_video_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ba70ee8f53089c9bc456dcf9ca4dda0faade1f --- /dev/null +++ b/eval_video_perplexity.py @@ -0,0 +1,134 @@ + +import os +import glob +from functools import partial +from tqdm import tqdm, trange +from multiprocessing import Pool +from PIL import Image +import cv2 +import mlxu +from natsort import natsorted +import numpy as np +import einops +import torch + +from vqlm_demo.inference import MultiProcessInferenceModel +from vqlm_demo.utils import ( + is_video, random_square_crop, + read_frames_from_dir, read_frames_from_video +) + + +FLAGS, _ = mlxu.define_flags_with_default( + checkpoint='', + input_files='', + frame_input=False, + read_file_list='', + center_crop=1.0, + n_context_frames=15, + n_target_frames=1, + n_workers=8, + stride=8, + batch_size=2, + torch_devices='', + shuffle=False, + random_start=True, + max_examples=0, +) + + +class VideoDataset(torch.utils.data.Dataset): + + def __init__(self, videos, frame_input=False, n_context_frames=15, + n_target_frames=1, stride=1): + self.videos = videos + self.frame_input = frame_input + self.n_context_frames = n_context_frames + self.n_target_frames = n_target_frames + self.stride = stride + + def __getitem__(self, index): + if self.frame_input: + frames = read_frames_from_dir( + self.videos[index], + self.n_context_frames + self.n_target_frames, + self.stride, + center_crop=FLAGS.center_crop, + random_start=FLAGS.random_start, + ) + else: + frames = read_frames_from_video( + self.videos[index], + self.n_context_frames + self.n_target_frames, + self.stride, + center_crop=FLAGS.center_crop, + random_start=FLAGS.random_start, + ) + if frames is None: + return self[np.random.randint(0, len(self))] + return frames[:self.n_context_frames], frames[self.n_context_frames:] + + def __len__(self): + return len(self.videos) + + + +def main(_): + assert FLAGS.checkpoint != '' + assert FLAGS.read_file_list != '' or FLAGS.input_files != '' + + model = MultiProcessInferenceModel( + checkpoint=FLAGS.checkpoint, + torch_devices=FLAGS.torch_devices, + perplexity_batch_size=FLAGS.batch_size, + ) + + if FLAGS.read_file_list != '': + with open(FLAGS.read_file_list, 'r') as f: + videos = [x.strip() for x in f.readlines()] + else: + videos = glob.glob(FLAGS.input_files) + + if FLAGS.frame_input: + videos = [x for x in videos if os.path.isdir(x)] + else: + videos = [x for x in videos if is_video(x)] + + if FLAGS.shuffle: + np.random.shuffle(videos) + + if FLAGS.max_examples > 0: + videos = 
videos[:FLAGS.max_examples]
+
+    dataset = VideoDataset(
+        videos,
+        frame_input=FLAGS.frame_input,
+        n_context_frames=FLAGS.n_context_frames,
+        n_target_frames=FLAGS.n_target_frames,
+        stride=FLAGS.stride
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size * model.n_processes * 4,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+        prefetch_factor=4,
+        drop_last=True,
+    )
+
+    perplexities = []
+
+    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
+        batch_context_frames = batch_context_frames.numpy()
+        batch_target_frames = batch_target_frames.numpy()
+        perplexity = model.compute_perplexity(
+            batch_context_frames, batch_target_frames
+        )
+        perplexities.append(perplexity)
+
+    perplexities = np.concatenate(perplexities, axis=0)
+    print(f'Perplexity: {np.mean(perplexities)}')
+
+
+if __name__ == '__main__':
+    mlxu.run(main)
\ No newline at end of file
diff --git a/eval_videos.py b/eval_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..4822e54d1de6e622eb5e2ccd6212f12c5130ea7a
--- /dev/null
+++ b/eval_videos.py
@@ -0,0 +1,160 @@
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+    is_video, random_square_crop,
+    read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    checkpoint='',
+    input_files='',
+    frame_input=False,
+    read_file_list='',
+    output_dir='',
+    center_crop=1.0,
+    n_context_frames=12,
+    n_new_frames=4,
+    n_candidates=8,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    stride=8,
+    batch_size=32,
+    torch_devices='',
+    shuffle=False,
+    max_examples=0,
+)
+
+
+def save_image(args):
+    image, filename = args
+    base = FLAGS.input_files.split('*')[0]
+    filename = filename[len(base):].replace('/', '_') + '.png'
+    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_frames=8, stride=1, new_frames=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_frames = n_frames
+        self.stride = stride
+        self.new_frames = new_frames
+
+    def __getitem__(self, index):
+        if self.frame_input:
+            frames = read_frames_from_dir(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        else:
+            # 's h w c'
+            frames = read_frames_from_video(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+
+        if frames is None:
+            return self[np.random.randint(0, len(self))]
+
+        # the last new_frames frames of the clip serve as the evaluation targets
+        target_frames = frames[self.n_frames - self.new_frames:self.n_frames, :, :, :]
+
+        return frames, target_frames, self.videos[index]
+
+    def __len__(self):
+        return len(self.videos)
+
+
+
+def main(_):
+    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+    os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+    if FLAGS.read_file_list != '':
+        with open(FLAGS.read_file_list, 'r') as f:
+            videos = [x.strip() for x in f.readlines()]
+    else:
+        videos = glob.glob(FLAGS.input_files)
+
+    if FLAGS.frame_input:
+        videos = [x for x in videos if os.path.isdir(x)]
+    else:
+        videos = [x for x in videos if is_video(x)]
+
+    if FLAGS.shuffle:
+        np.random.shuffle(videos)
+
+    if FLAGS.max_examples > 0:
+        videos = 
videos[:FLAGS.max_examples]
+
+    dataset = VideoDataset(
+        videos,
+        frame_input=FLAGS.frame_input,
+        n_frames=FLAGS.n_context_frames,
+        stride=FLAGS.stride
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+        prefetch_factor=4,
+        drop_last=True,
+    )
+
+    if FLAGS.torch_devices == '':
+        torch_devices = None
+    else:
+        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
+    )
+
+    save_img_pool = Pool(FLAGS.n_workers)
+
+    for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):
+
+        batch = batch.numpy()  # 'b s h w c '
+
+        generated = model(
+            batch,
+            n_new_frames=FLAGS.n_new_frames,
+            n_candidates=FLAGS.n_candidates,
+            temperature=FLAGS.temperature,
+            top_p=FLAGS.top_p,
+        )
+
+        generated = np.array(generated)
+
+        batch_targets = einops.repeat(
+            batch_targets.numpy(),
+            'b s h w c -> b n s h w c',  # batch, candidate, sequence, h, w, c.
+            n=FLAGS.n_candidates,
+        )
+
+
+if __name__ == '__main__':
+    mlxu.run(main)
\ No newline at end of file
diff --git a/generate_videos.py b/generate_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f6d129adb843cd53723f32e16555a3108cda76
--- /dev/null
+++ b/generate_videos.py
@@ -0,0 +1,168 @@
+
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+    is_video, random_square_crop,
+    read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    checkpoint='',
+    input_files='',
+    frame_input=False,
+    read_file_list='',
+    output_dir='',
+    center_crop=1.0,
+    n_context_frames=12,
+    n_new_frames=4,
+    n_candidates=8,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    stride=8,
+    batch_size=32,
+    torch_devices='',
+    shuffle=False,
+    max_examples=0,
+)
+
+
+def save_image(args):
+    image, filename = args
+    base = FLAGS.input_files.split('*')[0]
+    filename = filename[len(base):].replace('/', '_') + '.png'
+    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_frames = n_frames
+        self.stride = stride
+
+    def __getitem__(self, index):
+        if self.frame_input:
+            frames = read_frames_from_dir(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        else:
+            frames = read_frames_from_video(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        if frames is None:
+            return self[np.random.randint(0, len(self))]
+        return frames, self.videos[index]
+
+    def __len__(self):
+        return len(self.videos)
+
+
+
+def main(_):
+    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+    os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+    if FLAGS.read_file_list != '':
+        with open(FLAGS.read_file_list, 'r') as f:
+            videos = [x.strip() for x in f.readlines()]
+    else:
+        videos = glob.glob(FLAGS.input_files)
+
+    if FLAGS.frame_input:
+        videos = [x for x in videos if os.path.isdir(x)]
+    else:
+        videos = [x 
for x in videos if is_video(x)] + + if FLAGS.shuffle: + np.random.shuffle(videos) + + if FLAGS.max_examples > 0: + videos = videos[:FLAGS.max_examples] + + dataset = VideoDataset( + videos, + frame_input=FLAGS.frame_input, + n_frames=FLAGS.n_context_frames, + stride=FLAGS.stride + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=FLAGS.batch_size, + shuffle=False, + num_workers=FLAGS.n_workers, + prefetch_factor=4, + drop_last=True, + ) + + if FLAGS.torch_devices == '': + torch_devices = None + else: + torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')] + + model = MultiProcessInferenceModel( + checkpoint=FLAGS.checkpoint, torch_devices=torch_devices, + ) + + save_img_pool = Pool(FLAGS.n_workers) + + + + for batch, filenames in tqdm(dataloader, ncols=0): + + + + batch = batch.numpy() + + + + generated = model( + batch, + n_new_frames=FLAGS.n_new_frames, + n_candidates=FLAGS.n_candidates, + temperature=FLAGS.temperature, + top_p=FLAGS.top_p, + ) + + + generated = np.array(generated) + + + + + output_batch = einops.repeat( + batch, + 'b s h w c -> b n s h w c', + n=FLAGS.n_candidates, + ) + + + combined = einops.rearrange( + np.concatenate([output_batch, generated], axis=2), + 'b n s h w c -> b (n h) (s w) c' + ) + + + combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8) + save_img_pool.imap(save_image, zip(combined, filenames)) + + +if __name__ == '__main__': + mlxu.run(main) \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..73b1b816a632da1e9fa2ec70b608c1d04fc54e1a --- /dev/null +++ b/inference.py @@ -0,0 +1,240 @@ +from abc import ABC, abstractmethod +from contextlib import nullcontext +import time +import os +from functools import partial +from copy import deepcopy +from multiprocessing import Pool +from threading import Lock +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F +import einops +from transformers import LlamaForCausalLM +import spaces + +from vqvae_muse import VQGANModel, get_tokenizer_muse +from torch_vqvae_model import get_tokenizer + + +def get_torch_float_dtype(dtype): + if dtype in (torch.float16, torch.bfloat16, torch.float32): + return dtype + return { + 'float16': torch.float16, + 'fp16': torch.float16, + 'f16': torch.float16, + 'bfloat16': torch.bfloat16, + 'bf16': torch.bfloat16, + 'float32': torch.float32, + 'fp32': torch.float32, + 'f32': torch.float32, + }[dtype] + + +def get_pid(): + time.sleep(1) + return os.getpid() + + +class InferenceModel(ABC): + + @abstractmethod + def __call__(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): + raise NotImplementedError() + + +class LocalInferenceModel(InferenceModel): + + def __init__(self, checkpoint, dtype='float16', torch_device='cuda', + context_frames=16, use_lock=False): + self.checkpoint = checkpoint + self.dtype = dtype + self.torch_device = torch_device + self.context_frames = context_frames + + # new tokenizer + self.tokenizer = get_tokenizer_muse() + self.tokenizer.to(self.torch_device) + + self.model = LlamaForCausalLM.from_pretrained( + self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype) + ).to(self.torch_device) + print("torch device", self.torch_device) + print("init device", self.model.device) + + if use_lock: + self.lock = Lock() + else: + self.lock = nullcontext() + + @torch.no_grad() + def compute_perplexity(self, input_images, target_images): + input_images = np.array(input_images) + target_images = 
np.array(target_images) + assert len(input_images.shape) == 5 and len(target_images.shape) == 5 # [B, S, H, W, C] + assert input_images.shape[0] == target_images.shape[0] + batch_size = input_images.shape[0] + with self.lock: + input_images = torch.tensor( + einops.rearrange(input_images, 'b s h w c -> b s c h w') + ).to(self.torch_device) + target_images = torch.tensor( + einops.rearrange(target_images, 'b s h w c -> b s c h w') + ).to(self.torch_device) + input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1) + target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1) + all_ids = torch.cat([input_ids, target_ids], dim=1) + logits = self.model(all_ids).logits + log_probs = F.log_softmax(logits, dim=-1) + target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1]) + target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1] + perplexity = torch.exp( + -torch.mean( + torch.sum(target_log_probs * target_ids_onehot, dim=-1), + dim=-1 + ) + ) + return perplexity.detach().cpu().numpy() + + @torch.no_grad() + def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0): + assert type(input_images) == np.ndarray + with self.lock: + input_images = np.array(input_images, dtype=np.float32) + input_images = torch.tensor( + einops.rearrange(input_images, 'b h w c -> b c h w') + ).to(self.torch_device) + + # not quite sure why i need to redo it here + self.model.to(self.torch_device) + self.tokenizer.to(self.torch_device) + + # new tokenizer + _, input_ids = self.tokenizer.encode(input_images) + input_ids = input_ids.view(1, -1) + + + input_ids = input_ids[:, -(self.context_frames - 1) * 256:] + + new_tokens = [] + current_context_frames = input_ids.shape[1] // 256 + fisrt_generation_left = self.context_frames - current_context_frames + first_new_frames = min(fisrt_generation_left, n_new_frames) + input_ids = self.model.generate( + input_ids=input_ids, + attention_mask=torch.ones_like(input_ids), + pad_token_id=8192, + max_new_tokens=256 * first_new_frames, + do_sample=True, + top_p=top_p, + temperature=temperature, + suppress_tokens=list(range(8192, self.model.vocab_size)), + ) + new_tokens.append(input_ids[:, -256 * first_new_frames:]) + input_ids = input_ids[:, -(self.context_frames - 1) * 256:] + + for _ in range(max(0, n_new_frames - first_new_frames)): + input_ids = self.model.generate( + input_ids=input_ids, + attention_mask=torch.ones_like(input_ids), + pad_token_id=8192, + max_new_tokens=256, + do_sample=True, + top_p=top_p, + temperature=temperature, + suppress_tokens=list(range(8192, self.model.vocab_size)), + ) + new_tokens.append(input_ids[:, -256:]) + input_ids = input_ids[:, -(self.context_frames - 1) * 256:] + + new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256) + new_images = einops.rearrange( + torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0), + 'b c h w -> b h w c' + ).detach().cpu().numpy() + return new_images + + @spaces.GPU(duration=180) + def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): + output = [] + for seq in input_images: + output.append( + [self.generate_once(seq, n_new_frames, temperature, top_p) + for _ in range(n_candidates)] + ) + return output + + +class MultiProcessInferenceModel(InferenceModel): + + def __init__(self, checkpoint, torch_devices=None, dtype='float16', + context_frames=16, use_lock=False, perplexity_batch_size=2): + if torch_devices is None or torch_devices == '': + torch_devices = [f'cuda:{i}' for i in 
range(torch.cuda.device_count())] + + self.torch_devices = torch_devices + self.n_processes = len(torch_devices) + print(f'Using {self.n_processes} processes for inference') + self.worker_pool = Pool(self.n_processes) + self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)]) + self.device_map = { + pid: torch_device + for pid, torch_device in zip(self.worker_pids, self.torch_devices) + } + self.worker_pool.starmap( + self.initialize_worker, + [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)] + ) + self.perplexity_batch_size = perplexity_batch_size + if use_lock: + self.lock = Lock() + else: + self.lock = nullcontext() + + @staticmethod + def initialize_worker(device_map, checkpoint, dtype, context_frames): + global _current_process_backend + torch_device = device_map[os.getpid()] + _current_process_backend = LocalInferenceModel( + checkpoint, dtype, torch_device, context_frames + ) + + @staticmethod + def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0): + return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p) + + @staticmethod + def compute_perplexity_once(input_images, target_images): + return _current_process_backend.compute_perplexity(input_images, target_images) + + def compute_perplexity(self, input_images, target_images): + with self.lock: + map_args = [] + for i in range(0, len(input_images), self.perplexity_batch_size): + map_args.append(( + input_images[i : i + self.perplexity_batch_size], + target_images[i : i + self.perplexity_batch_size] + )) + outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args) + return np.concatenate(outputs, axis=0) + + def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0): + with self.lock: + map_args = [] + for seq in input_images: + for _ in range(n_candidates): + map_args.append((seq, n_new_frames, temperature, top_p)) + + outputs = self.worker_pool.starmap(self.generate_once, map_args) + reshaped_output = [] + index = 0 + for _ in range(len(input_images)): + candidates = [] + for _ in range(n_candidates): + candidates.append(outputs[index]) + index += 1 + reshaped_output.append(candidates) + return reshaped_output + diff --git a/prompts/.DS_Store b/prompts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..3b3b755588b8880f623b777ccfaf36048c71b851 Binary files /dev/null and b/prompts/.DS_Store differ diff --git a/prompts/Composition/Slide1.png b/prompts/Composition/Slide1.png new file mode 100644 index 0000000000000000000000000000000000000000..c20dd2a265aef200d02fec3495af8cdb4fece30d --- /dev/null +++ b/prompts/Composition/Slide1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d926922e8e28f02c46e723b85d3d4969da271f892654ce492cf59cbf3f322a0 +size 194501 diff --git a/prompts/Composition/Slide10.png b/prompts/Composition/Slide10.png new file mode 100644 index 0000000000000000000000000000000000000000..c8a0e88cbc36c13d575f223edd6681fe95f63a86 --- /dev/null +++ b/prompts/Composition/Slide10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4480a3a7905b703ca1802e5391ea47e90e84cdc7eacb5229ade606ce4f5b6bb +size 443693 diff --git a/prompts/Composition/Slide11.png b/prompts/Composition/Slide11.png new file mode 100644 index 0000000000000000000000000000000000000000..0549fe330577e0adc697dc03fb284aa15b14f441 --- /dev/null +++ b/prompts/Composition/Slide11.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:91cbe861bd47c4ec08e79bccdb64b993cc4b3b21549c346f834a985b1b0a1a6e +size 464548 diff --git a/prompts/Composition/Slide12.png b/prompts/Composition/Slide12.png new file mode 100644 index 0000000000000000000000000000000000000000..31116cc53413933baaaf93aa5c7a4373e713944d --- /dev/null +++ b/prompts/Composition/Slide12.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d05d2db2a5e7bc7e33795583e10cdc03ea53bacd250010680a161ab07b7ad65 +size 487835 diff --git a/prompts/Composition/Slide13.png b/prompts/Composition/Slide13.png new file mode 100644 index 0000000000000000000000000000000000000000..f5c89f17f384cb855047917b3bdc589919cd4504 --- /dev/null +++ b/prompts/Composition/Slide13.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d94cfad17df77fa90ab84bdd89d3ad09938a5fe768b4e211c2bac140b36c12cb +size 489967 diff --git a/prompts/Composition/Slide14.png b/prompts/Composition/Slide14.png new file mode 100644 index 0000000000000000000000000000000000000000..de90d3fa3d4b3af9d32fbce6803389b072d26322 --- /dev/null +++ b/prompts/Composition/Slide14.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04b42409ec1ca2ddbde1114eb8426a34c5e0064159e224af808b766ae003d2fd +size 492423 diff --git a/prompts/Composition/Slide15.png b/prompts/Composition/Slide15.png new file mode 100644 index 0000000000000000000000000000000000000000..871bd579690ad66a0c714a1e5fdc33846aa9147c --- /dev/null +++ b/prompts/Composition/Slide15.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9919156ccdd9c2cbb30811529e94c83bb2afb277c90fed503ea4716be702cdde +size 491891 diff --git a/prompts/Composition/Slide2.png b/prompts/Composition/Slide2.png new file mode 100644 index 0000000000000000000000000000000000000000..dbe072dca6d9ef3fe4491d740af5df5c6b010c68 --- /dev/null +++ b/prompts/Composition/Slide2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0c9f6467cc732b562c167770d38a164162e4454127a242a16e3bdae7e717d27 +size 193143 diff --git a/prompts/Composition/Slide3.png b/prompts/Composition/Slide3.png new file mode 100644 index 0000000000000000000000000000000000000000..3e8ac80de12ba6e1b5c4bf4e6bfa6ac7ebad7ad6 --- /dev/null +++ b/prompts/Composition/Slide3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f702f10001fd9e7ad523753c884f8cef532da878d62656ffdbd566e104b67c7 +size 199394 diff --git a/prompts/Composition/Slide4.png b/prompts/Composition/Slide4.png new file mode 100644 index 0000000000000000000000000000000000000000..7e0643b567bf4a0181ff14b6b954356c30ad7b06 --- /dev/null +++ b/prompts/Composition/Slide4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18c2d2384e4c97f35ae4cddc9bea4e600946eefefefff1f4fb683a51a54d4384 +size 202638 diff --git a/prompts/Composition/Slide5.png b/prompts/Composition/Slide5.png new file mode 100644 index 0000000000000000000000000000000000000000..59c9ae2435567cdce73774b3ef5342d70a0f13da --- /dev/null +++ b/prompts/Composition/Slide5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4af292b97a2abe48d253fb2f1badd8d147402a3124fd12a2a0750307487c4f27 +size 190546 diff --git a/prompts/Composition/Slide6.png b/prompts/Composition/Slide6.png new file mode 100644 index 0000000000000000000000000000000000000000..fb2d05758aadbd0ee184bc274ece6ede5714dd6c --- /dev/null +++ b/prompts/Composition/Slide6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:f8b5fe9521e4094950384fce57733496363750b6a7c816ebae3cf43e6bcdb626 +size 173097 diff --git a/prompts/Composition/Slide7.png b/prompts/Composition/Slide7.png new file mode 100644 index 0000000000000000000000000000000000000000..aa5e53aa0367d64009edede3023ab3ab1cdfa196 --- /dev/null +++ b/prompts/Composition/Slide7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce9418614363adfcd1b96b6df3b990d8204ed0a0341c348f9f340d7c128b4900 +size 174070 diff --git a/prompts/Composition/Slide8.png b/prompts/Composition/Slide8.png new file mode 100644 index 0000000000000000000000000000000000000000..d214f14a6149fd5fdf7a9a55b558b2e74b192359 --- /dev/null +++ b/prompts/Composition/Slide8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66a589f649600c65b7e808d824322d5a7c36b39675704cd5857fc31ce4f5af7f +size 180144 diff --git a/prompts/Composition/Slide9.png b/prompts/Composition/Slide9.png new file mode 100644 index 0000000000000000000000000000000000000000..c2be5efa234c711a7da20cd11f717e797b1e9bf8 --- /dev/null +++ b/prompts/Composition/Slide9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee514628d3da8c4525853c86d1e8d348de7a0641312a0e3c79fae6b5d73ae11f +size 454702 diff --git a/prompts/Depth Estimation/1.png b/prompts/Depth Estimation/1.png new file mode 100644 index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62 --- /dev/null +++ b/prompts/Depth Estimation/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2 +size 48533 diff --git a/prompts/Depth Estimation/1_depth.png b/prompts/Depth Estimation/1_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..e4a416e3cf5ee9e7b46f95c5590905192d50f99f --- /dev/null +++ b/prompts/Depth Estimation/1_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b22aa119576ab691bab3db3fdd7eacf53dadc9e4cb3a9bfe4f4cb9c6fc0f6c6 +size 13888 diff --git a/prompts/Depth Estimation/2.png b/prompts/Depth Estimation/2.png new file mode 100644 index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5 --- /dev/null +++ b/prompts/Depth Estimation/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b +size 54286 diff --git a/prompts/Depth Estimation/2_depth.png b/prompts/Depth Estimation/2_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..70e9ccb6b44f0284cac2e39bc3e9e4981d5ac373 --- /dev/null +++ b/prompts/Depth Estimation/2_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37eacaf9208cf21693ae99802697e8894a9e8cf40cc221c704a50358f14dc954 +size 12257 diff --git a/prompts/Depth Estimation/3.png b/prompts/Depth Estimation/3.png new file mode 100644 index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18 --- /dev/null +++ b/prompts/Depth Estimation/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00 +size 52593 diff --git a/prompts/Depth Estimation/3_depth.png b/prompts/Depth Estimation/3_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..f39d22d9069072ee46affd295b14f8e02a70ceb6 --- /dev/null +++ b/prompts/Depth Estimation/3_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:a91a47d21378bef0d535f7e33c0185e60cc23baa7ced20bc1ffb028a5d95b5c4 +size 13332 diff --git a/prompts/Depth Estimation/4.png b/prompts/Depth Estimation/4.png new file mode 100644 index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a --- /dev/null +++ b/prompts/Depth Estimation/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d +size 60589 diff --git a/prompts/Depth Estimation/4_depth.png b/prompts/Depth Estimation/4_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..d578e93b900f2e0d74ca15723de3c92b90fc66d2 --- /dev/null +++ b/prompts/Depth Estimation/4_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0685d0c4755206910cb1b1feea54a1e843cdee9dd140e414c0df56a885b68d85 +size 13447 diff --git a/prompts/Depth Estimation/5.png b/prompts/Depth Estimation/5.png new file mode 100644 index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd --- /dev/null +++ b/prompts/Depth Estimation/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b +size 21984 diff --git a/prompts/Depth Estimation/5_depth.png b/prompts/Depth Estimation/5_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..9790c23cff31b23afe755370fae8c4fbac6316f3 --- /dev/null +++ b/prompts/Depth Estimation/5_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e9874362d8a1c85b0030590399a5f6388fe69b9dd42ec762313b97d37817eb7 +size 12020 diff --git a/prompts/Depth Estimation/6.png b/prompts/Depth Estimation/6.png new file mode 100644 index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252 --- /dev/null +++ b/prompts/Depth Estimation/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547 +size 30704 diff --git a/prompts/Depth Estimation/6_depth.png b/prompts/Depth Estimation/6_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..4f84aeef983f2a70a1a26033942ffbc1247eeaa5 --- /dev/null +++ b/prompts/Depth Estimation/6_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:588df6be45864d2164b5e215429f268ba82731540adb07cd4ea47db0ca8f5319 +size 11946 diff --git a/prompts/Depth Estimation/7.png b/prompts/Depth Estimation/7.png new file mode 100644 index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6 --- /dev/null +++ b/prompts/Depth Estimation/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7 +size 49450 diff --git a/prompts/Depth Estimation/7_depth.png b/prompts/Depth Estimation/7_depth.png new file mode 100644 index 0000000000000000000000000000000000000000..fe14a639f330ac40b0ec82a7c19e8f61befe9496 --- /dev/null +++ b/prompts/Depth Estimation/7_depth.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0edcab180411a6966d899de8f282870293a5d275e58d4a185e3cb31d9ca6b0d +size 13252 diff --git a/prompts/Depth Estimation/8.png b/prompts/Depth Estimation/8.png new file mode 100644 index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d --- /dev/null +++ b/prompts/Depth Estimation/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38 +size 50877 diff --git a/prompts/Eaten Apples/1.png b/prompts/Eaten Apples/1.png new file mode 100644 index 0000000000000000000000000000000000000000..d32aa001b22d9a408645e06cc02ad52505e230a2 --- /dev/null +++ b/prompts/Eaten Apples/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a75364bb67ce5741004e2bb18178b362fd1d4dee12a76d9ae4be2124fb3452a0 +size 199368 diff --git a/prompts/Eaten Apples/10.png b/prompts/Eaten Apples/10.png new file mode 100644 index 0000000000000000000000000000000000000000..aa963c0a65d4e7605efbabc10244c171172835be --- /dev/null +++ b/prompts/Eaten Apples/10.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05f9235b7c283915d0d81b2423915f05f587b91d691ae0cae6f0bc5b68e84588 +size 142649 diff --git a/prompts/Eaten Apples/2.png b/prompts/Eaten Apples/2.png new file mode 100644 index 0000000000000000000000000000000000000000..dadf5d493d454a5ec31696919aa2be33eea1d6ab --- /dev/null +++ b/prompts/Eaten Apples/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25b08de2de0ac2bcc59be060bf19574931091c9dc6472f8122f7ac1243c59c6f +size 214103 diff --git a/prompts/Eaten Apples/3.png b/prompts/Eaten Apples/3.png new file mode 100644 index 0000000000000000000000000000000000000000..17bb70919745f0f0401140da26a4aa67c9224875 --- /dev/null +++ b/prompts/Eaten Apples/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eee49d97068ac9de19bf6aead9a4b0c88ba9108bc9eb9d19f43a3b5919c88367 +size 212059 diff --git a/prompts/Eaten Apples/4.png b/prompts/Eaten Apples/4.png new file mode 100644 index 0000000000000000000000000000000000000000..bc6979b7ca710ff50f9af0f26f6089d9a25a0b53 --- /dev/null +++ b/prompts/Eaten Apples/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88d401e4c2c2b1b21119b953e230a276305af15d295d0035a221e498665af5b4 +size 212147 diff --git a/prompts/Eaten Apples/5.png b/prompts/Eaten Apples/5.png new file mode 100644 index 0000000000000000000000000000000000000000..60e4028c1752e1caf2581949529b7c76db2b788b --- /dev/null +++ b/prompts/Eaten Apples/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:116e8eb9ecc170c4a00f54a3b7b8996b67cd585932e34f4e2a25f8e589b7ae3d +size 204197 diff --git a/prompts/Eaten Apples/6.png b/prompts/Eaten Apples/6.png new file mode 100644 index 0000000000000000000000000000000000000000..3011620226f9165f918dd3e1120a3569e4ef3bfd --- /dev/null +++ b/prompts/Eaten Apples/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c683df4c90c7da5bee98499fbf7233b6ac13fe2480aa9e1d4cb80a25ff9a500 +size 192756 diff --git a/prompts/Eaten Apples/7.png b/prompts/Eaten Apples/7.png new file mode 100644 index 0000000000000000000000000000000000000000..c890af0c7b3b95c297ff07db0da4134754de58b4 --- /dev/null +++ b/prompts/Eaten Apples/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e3ad84f16c9326a9819da7e2c9485705b073a47097f742592ada91c10f706c0 +size 181082 diff --git a/prompts/Eaten Apples/8.png b/prompts/Eaten Apples/8.png new file mode 100644 index 0000000000000000000000000000000000000000..89545aa186079020196c335da0200c2a1c88c80a --- /dev/null +++ b/prompts/Eaten Apples/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:919a814cf9c604923a3384c26473750d353bac17f3422865b68c1d86e45552f7 +size 167449 diff --git a/prompts/Edge Detection/1.png b/prompts/Edge Detection/1.png new file 
mode 100644 index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62 --- /dev/null +++ b/prompts/Edge Detection/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2 +size 48533 diff --git a/prompts/Edge Detection/1_edge.png b/prompts/Edge Detection/1_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..48f7dee44046eb0e7853d81be37870da7f7578fb --- /dev/null +++ b/prompts/Edge Detection/1_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ca5651653eb5ed8e3d08600c936985982f5019ce6f2f2489e82c112ea75686a +size 30563 diff --git a/prompts/Edge Detection/2.png b/prompts/Edge Detection/2.png new file mode 100644 index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5 --- /dev/null +++ b/prompts/Edge Detection/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b +size 54286 diff --git a/prompts/Edge Detection/2_edge.png b/prompts/Edge Detection/2_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..49db076ecece6c9db3bee69a720b8207fadfd8e5 --- /dev/null +++ b/prompts/Edge Detection/2_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8a5327db1e20457b8ab478a296e5176f2779c5e4bb2c034e5b6f0183854866b +size 30437 diff --git a/prompts/Edge Detection/3.png b/prompts/Edge Detection/3.png new file mode 100644 index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18 --- /dev/null +++ b/prompts/Edge Detection/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00 +size 52593 diff --git a/prompts/Edge Detection/3_edge.png b/prompts/Edge Detection/3_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..4f234ff898f4fc9f390bc4d3e5fd114662411b51 --- /dev/null +++ b/prompts/Edge Detection/3_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d12107a3ead99860b853db28ddd7c8cc3fe65966cdb9221f6afd2154dadeb507 +size 32196 diff --git a/prompts/Edge Detection/4.png b/prompts/Edge Detection/4.png new file mode 100644 index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a --- /dev/null +++ b/prompts/Edge Detection/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d +size 60589 diff --git a/prompts/Edge Detection/4_edge.png b/prompts/Edge Detection/4_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..171d43defccfc3a9473faff804b31564483380d6 --- /dev/null +++ b/prompts/Edge Detection/4_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a4d3730f25b2b7305dfc56875b12135b99c0a93cba8ac1a1ff899b7d68eb8ef +size 39602 diff --git a/prompts/Edge Detection/5.png b/prompts/Edge Detection/5.png new file mode 100644 index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd --- /dev/null +++ b/prompts/Edge Detection/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b +size 21984 diff --git a/prompts/Edge Detection/5_edge.png b/prompts/Edge Detection/5_edge.png new file mode 100644 index 
0000000000000000000000000000000000000000..c2b7f3b15d3e3b4aab01eabebccb8556531d7619 --- /dev/null +++ b/prompts/Edge Detection/5_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c958bf6a35c685e34ae07216dd112ac4fcedd9c4629d096615333bb69b45d45c +size 16448 diff --git a/prompts/Edge Detection/6.png b/prompts/Edge Detection/6.png new file mode 100644 index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252 --- /dev/null +++ b/prompts/Edge Detection/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547 +size 30704 diff --git a/prompts/Edge Detection/6_edge.png b/prompts/Edge Detection/6_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..9d62a8aa115b5cde4f7a317a744f96673102a8c4 --- /dev/null +++ b/prompts/Edge Detection/6_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75383c3fb8acb0318fb95c48c43b56df8574ca119e0f9f577066dab8fdb8fca3 +size 36706 diff --git a/prompts/Edge Detection/7.png b/prompts/Edge Detection/7.png new file mode 100644 index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6 --- /dev/null +++ b/prompts/Edge Detection/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7 +size 49450 diff --git a/prompts/Edge Detection/7_edge.png b/prompts/Edge Detection/7_edge.png new file mode 100644 index 0000000000000000000000000000000000000000..be7cf97bb9f5f33d331ece657a7f23817bccf235 --- /dev/null +++ b/prompts/Edge Detection/7_edge.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93a646f6a4fcd48b949aba04dbdeb781b01e0425a32a03d960e7c9617375fe90 +size 29210 diff --git a/prompts/Edge Detection/8.png b/prompts/Edge Detection/8.png new file mode 100644 index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d --- /dev/null +++ b/prompts/Edge Detection/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38 +size 50877 diff --git a/prompts/Emoji/smile1.png b/prompts/Emoji/smile1.png new file mode 100644 index 0000000000000000000000000000000000000000..069348959a22a84e5993121e89cf401fa576c923 --- /dev/null +++ b/prompts/Emoji/smile1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8923aa6be860685271dbef2fcc1fc87230dec0dca44d0815f9ccdf1f8d5aea26 +size 21247 diff --git a/prompts/Emoji/smile2.png b/prompts/Emoji/smile2.png new file mode 100644 index 0000000000000000000000000000000000000000..01948d4b355fdccce546cd367554db3724231938 --- /dev/null +++ b/prompts/Emoji/smile2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beadf44be4d267c8f88d684bfd6b02ca13eeff5fb007f2fbfc8e2f52b8459c64 +size 22703 diff --git a/prompts/Emoji/smile3.png b/prompts/Emoji/smile3.png new file mode 100644 index 0000000000000000000000000000000000000000..a527c61b51f40daa4ad8425342a3918f140143bf --- /dev/null +++ b/prompts/Emoji/smile3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:919155477aa8d2d17502cacf0dbdcec8b8feba5faa2065e86b832ac6020e2169 +size 24308 diff --git a/prompts/Emoji/smile4.png b/prompts/Emoji/smile4.png new file mode 100644 index 0000000000000000000000000000000000000000..cc8f33154f294ae7c6141c6df333645f8e314f80 --- /dev/null +++ b/prompts/Emoji/smile4.png @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9571b1c938cd2ccb522bdd15cddb4ff93b84c6ac938923a34f1ee01fdb2a002a +size 24451 diff --git a/prompts/Object Tracking/Picture1.png b/prompts/Object Tracking/Picture1.png new file mode 100644 index 0000000000000000000000000000000000000000..9ee9959fc01d71202482304f5031bf3cb7fa04db --- /dev/null +++ b/prompts/Object Tracking/Picture1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d311a18e7a9c934ce775fa6db6576f9ed27b9fcac82f224626af8b5b074b3c35 +size 733163 diff --git a/prompts/Object Tracking/Picture2.png b/prompts/Object Tracking/Picture2.png new file mode 100644 index 0000000000000000000000000000000000000000..5f23477cb63d62c8b57fe69360905c617bde0155 --- /dev/null +++ b/prompts/Object Tracking/Picture2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52a2437f5bd694d69a384c82ad2b163eb665b50f2814a4a30cc1719e4436393f +size 730944 diff --git a/prompts/Object Tracking/Picture3.png b/prompts/Object Tracking/Picture3.png new file mode 100644 index 0000000000000000000000000000000000000000..580685c71839bb687193ff48f07300fb423ed80b --- /dev/null +++ b/prompts/Object Tracking/Picture3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d31cdf20c257094e3624cc032c975a5326d904d53664ed7acfd022c829961e96 +size 723189 diff --git a/prompts/Object Tracking/Picture4.png b/prompts/Object Tracking/Picture4.png new file mode 100644 index 0000000000000000000000000000000000000000..85b9aa885f9187d6b7124a04356874280a099caf --- /dev/null +++ b/prompts/Object Tracking/Picture4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb5e559a8421bd48bcdb2689142b52ad3384a49d8b59c8f00fb65ab583c62cb1 +size 709894 diff --git a/prompts/Object Tracking/Picture5.png b/prompts/Object Tracking/Picture5.png new file mode 100644 index 0000000000000000000000000000000000000000..f8b0f91b49529756511fd094ced04288c2636137 --- /dev/null +++ b/prompts/Object Tracking/Picture5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e486056f0c6898326b54410878f1c0deeba2a5c5db7acec9ba5a14617ab59068 +size 690117 diff --git a/prompts/Object Tracking/Picture6.png b/prompts/Object Tracking/Picture6.png new file mode 100644 index 0000000000000000000000000000000000000000..10e56c3f9db3a014eef054561236356c444d84c3 --- /dev/null +++ b/prompts/Object Tracking/Picture6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf16762c2dd316e52721618f96bcc0811118f812427e933fef50f937adaea3d4 +size 671387 diff --git a/prompts/Object Tracking/Picture7.png b/prompts/Object Tracking/Picture7.png new file mode 100644 index 0000000000000000000000000000000000000000..c41b3438aa20a8a8f6460003206c945100658073 --- /dev/null +++ b/prompts/Object Tracking/Picture7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:227bf301dc7f06208e897ab3b6092e15e0b32f7ca0ef77a51f8fb81bd258e9a3 +size 654387 diff --git a/prompts/Object Tracking/Picture8.png b/prompts/Object Tracking/Picture8.png new file mode 100644 index 0000000000000000000000000000000000000000..6de5ddbfcfa353d01d176426271f2fca2c770961 --- /dev/null +++ b/prompts/Object Tracking/Picture8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e38170d147a836c88782b37d7f2bd60ee051016c6daf53ea879320f0967c818c +size 642514 diff --git a/prompts/Object Tracking/Picture9.png b/prompts/Object Tracking/Picture9.png new file mode 100644 index 
0000000000000000000000000000000000000000..117698ee19e804cc082023da4dff47427f68532f --- /dev/null +++ b/prompts/Object Tracking/Picture9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14dc7b90ae871b5be240e5775e1bd0be2ae18e73ef5ac87e7a96a928613c03b5 +size 622249 diff --git a/prompts/Outpainting/2.png b/prompts/Outpainting/2.png new file mode 100644 index 0000000000000000000000000000000000000000..b77b519a6a263166ce81d41efc2a17e528c5dafc --- /dev/null +++ b/prompts/Outpainting/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d73f108e47033bf08b8f5b4558e72c73231f5cb06e06c6283b782fe260739f1 +size 440489 diff --git a/prompts/Outpainting/3.png b/prompts/Outpainting/3.png new file mode 100644 index 0000000000000000000000000000000000000000..3e2d1fdfecf4447a2475ed09e23035f11d161e95 --- /dev/null +++ b/prompts/Outpainting/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fba02eef20fd3a32ff01fc43bae60811185ce34cac343699a4e2b7383a81e8ef +size 829260 diff --git a/prompts/Outpainting/4.png b/prompts/Outpainting/4.png new file mode 100644 index 0000000000000000000000000000000000000000..0375a22e663f9855153362011d827ee04dda1011 --- /dev/null +++ b/prompts/Outpainting/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e36507bd52fc628b911f8e3d2806549344fa51a41821bb8a31b772296c5c8207 +size 1295391 diff --git a/prompts/Outpainting/5.png b/prompts/Outpainting/5.png new file mode 100644 index 0000000000000000000000000000000000000000..48e5357ccdadd953ecbeb7e45c832f44f063f9d1 --- /dev/null +++ b/prompts/Outpainting/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42b3a730e6ad1adf868aac31dfadbd2773df5850b3602dcb84e4c10b57e3f4a1 +size 1980696 diff --git a/prompts/Segmentation/1.png b/prompts/Segmentation/1.png new file mode 100644 index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62 --- /dev/null +++ b/prompts/Segmentation/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2 +size 48533 diff --git a/prompts/Segmentation/1_seg.png b/prompts/Segmentation/1_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..7ea7e500d9d2dc981deb196fb5693acf96459636 --- /dev/null +++ b/prompts/Segmentation/1_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2642904a2acd1bbf4cc099ccf3bb0d0f7ee3370079eda0a21d04df75923a2dd +size 1234 diff --git a/prompts/Segmentation/2.png b/prompts/Segmentation/2.png new file mode 100644 index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5 --- /dev/null +++ b/prompts/Segmentation/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b +size 54286 diff --git a/prompts/Segmentation/2_seg.png b/prompts/Segmentation/2_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..a38cbb9400b37bcda2aa0d92c022a2883512728b --- /dev/null +++ b/prompts/Segmentation/2_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1070ec66a0b1295a20871f0ef04e2f150e708b8fedcd915733b62978cbfcd0d +size 2243 diff --git a/prompts/Segmentation/3.png b/prompts/Segmentation/3.png new file mode 100644 index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18 --- /dev/null +++ b/prompts/Segmentation/3.png @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00 +size 52593 diff --git a/prompts/Segmentation/3_seg.png b/prompts/Segmentation/3_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..4f3645ca6724fce3c749631221a657ae42dba25c --- /dev/null +++ b/prompts/Segmentation/3_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf8dc0e197459bdef57d3246e3cd5355f7a8c6d5f4822eb05fe1a8813d3f71b5 +size 1869 diff --git a/prompts/Segmentation/4.png b/prompts/Segmentation/4.png new file mode 100644 index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a --- /dev/null +++ b/prompts/Segmentation/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d +size 60589 diff --git a/prompts/Segmentation/4_seg.png b/prompts/Segmentation/4_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..358cb0caa8e26115253f7e579484ec207ecf11bf --- /dev/null +++ b/prompts/Segmentation/4_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:101793eff1e1cb3613cff94fd1cbd3e65ce9407a29dd9efd8ffa78f90443fa72 +size 3129 diff --git a/prompts/Segmentation/5.png b/prompts/Segmentation/5.png new file mode 100644 index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd --- /dev/null +++ b/prompts/Segmentation/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b +size 21984 diff --git a/prompts/Segmentation/5_seg.png b/prompts/Segmentation/5_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..4356240ef23bfbdff9edbf1a5ea36d75bae5e93a --- /dev/null +++ b/prompts/Segmentation/5_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fda0b493538ed52658c31498c678ce098229bd9c57c3e76db9096298f2b2309d +size 1814 diff --git a/prompts/Segmentation/6.png b/prompts/Segmentation/6.png new file mode 100644 index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252 --- /dev/null +++ b/prompts/Segmentation/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547 +size 30704 diff --git a/prompts/Segmentation/6_seg.png b/prompts/Segmentation/6_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..d0aa3baf6aed286426f0dd6a6190ce0c094333df --- /dev/null +++ b/prompts/Segmentation/6_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10e8d222b3db095de09124c8b129ac63da360776d588b1fb5027b31fa9b5d1d0 +size 1684 diff --git a/prompts/Segmentation/7.png b/prompts/Segmentation/7.png new file mode 100644 index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6 --- /dev/null +++ b/prompts/Segmentation/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7 +size 49450 diff --git a/prompts/Segmentation/7_seg.png b/prompts/Segmentation/7_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..f768b93c59c7e66c0617d1160a5381080b8b44f4 --- /dev/null +++ b/prompts/Segmentation/7_seg.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c11400a1bc6c696c2f479165e5ffe7d0be2212c252b311b29eeae2d111c927a 
+size 1887 diff --git a/prompts/Segmentation/8.png b/prompts/Segmentation/8.png new file mode 100644 index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d --- /dev/null +++ b/prompts/Segmentation/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38 +size 50877 diff --git a/prompts/Surface Normal/1.png b/prompts/Surface Normal/1.png new file mode 100644 index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62 --- /dev/null +++ b/prompts/Surface Normal/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2 +size 48533 diff --git a/prompts/Surface Normal/1_surfave_norm.png b/prompts/Surface Normal/1_surfave_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..e9865e90d1661296e48a27338eef101aa5b26ba3 --- /dev/null +++ b/prompts/Surface Normal/1_surfave_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f40b8f98c35392de51e72db501b26dd04f3fa057287839df4076f0402c2fe05b +size 42875 diff --git a/prompts/Surface Normal/2.png b/prompts/Surface Normal/2.png new file mode 100644 index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5 --- /dev/null +++ b/prompts/Surface Normal/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b +size 54286 diff --git a/prompts/Surface Normal/2_surface_norm.png b/prompts/Surface Normal/2_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..6ca2da625560053b50957186fc30d9c9249333b5 --- /dev/null +++ b/prompts/Surface Normal/2_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e52e3619a44c896fbfcc4cd5a60435fb2187c907da2122c7a5584363e8b7109 +size 45167 diff --git a/prompts/Surface Normal/3.png b/prompts/Surface Normal/3.png new file mode 100644 index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18 --- /dev/null +++ b/prompts/Surface Normal/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00 +size 52593 diff --git a/prompts/Surface Normal/3_surface_norm.png b/prompts/Surface Normal/3_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..a17702432fea0029d9ad22445dac6b9ba1a2b4af --- /dev/null +++ b/prompts/Surface Normal/3_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2bb1d3d41c0f72faf0c4fca3c0030021650d0f36707cdfdb381d3ef9a5edb6b +size 49907 diff --git a/prompts/Surface Normal/4.png b/prompts/Surface Normal/4.png new file mode 100644 index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a --- /dev/null +++ b/prompts/Surface Normal/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d +size 60589 diff --git a/prompts/Surface Normal/4_surface_norm.png b/prompts/Surface Normal/4_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..b92d9199042136de71381a3a0d3463d070f96728 --- /dev/null +++ b/prompts/Surface Normal/4_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:95d833513daa9de65dc68f3eeb4e589ec43d2e42e7305c789f4235f6213c5338 +size 46127 diff --git a/prompts/Surface Normal/5.png b/prompts/Surface Normal/5.png new file mode 100644 index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd --- /dev/null +++ b/prompts/Surface Normal/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b +size 21984 diff --git a/prompts/Surface Normal/5_surface_norm.png b/prompts/Surface Normal/5_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..d30c9ee11141aceff4713cc580a857d876a7c9f8 --- /dev/null +++ b/prompts/Surface Normal/5_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9898e670f6c5e5bdd4bb4836b7ad9632ea90e2e91d57a524f22214c9cf6ef2cc +size 34050 diff --git a/prompts/Surface Normal/6.png b/prompts/Surface Normal/6.png new file mode 100644 index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252 --- /dev/null +++ b/prompts/Surface Normal/6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547 +size 30704 diff --git a/prompts/Surface Normal/6_surface_norm.png b/prompts/Surface Normal/6_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..8188e4041605d668f198a92ba661882417302e51 --- /dev/null +++ b/prompts/Surface Normal/6_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b87d3709a62bfced6176861c06102bb1fe45583e92fadacb5a38562bba34339 +size 45259 diff --git a/prompts/Surface Normal/7.png b/prompts/Surface Normal/7.png new file mode 100644 index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6 --- /dev/null +++ b/prompts/Surface Normal/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7 +size 49450 diff --git a/prompts/Surface Normal/7_surface_norm.png b/prompts/Surface Normal/7_surface_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..d8955586a49136487bbd689017efa9d73ad2544a --- /dev/null +++ b/prompts/Surface Normal/7_surface_norm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84b9147bdcfa2c56c05bafc73d025efc251ba9b687dc9a83a00e2f9633c7a172 +size 44708 diff --git a/prompts/Surface Normal/8.png b/prompts/Surface Normal/8.png new file mode 100644 index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d --- /dev/null +++ b/prompts/Surface Normal/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38 +size 50877 diff --git a/prompts/Synthetic Object Replication/Slide1.png b/prompts/Synthetic Object Replication/Slide1.png new file mode 100644 index 0000000000000000000000000000000000000000..aaa3ff562f23d81a718d76104668bd190ea6ebc7 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d00ed3bc13809d3c1d54fbf851c334e3a0e6f088a4c7420c73af078e5517e05 +size 229453 diff --git a/prompts/Synthetic Object Replication/Slide2.png b/prompts/Synthetic Object Replication/Slide2.png new file mode 100644 index 0000000000000000000000000000000000000000..3de22bebc4cd4cc363f223ef323d48b0c75b3719 --- /dev/null +++ 
b/prompts/Synthetic Object Replication/Slide2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29aeaa8d76977173a58a957c463cb794a5a8d780b092706836fea4832b787f35 +size 235912 diff --git a/prompts/Synthetic Object Replication/Slide3.png b/prompts/Synthetic Object Replication/Slide3.png new file mode 100644 index 0000000000000000000000000000000000000000..e45d49883dc9528572525fc3b7e3ffff941eb68f --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0093baa4b93b0f75872637787cda5412fc371195b3f91188c1bf6e03dfdd49d2 +size 233967 diff --git a/prompts/Synthetic Object Replication/Slide4.png b/prompts/Synthetic Object Replication/Slide4.png new file mode 100644 index 0000000000000000000000000000000000000000..97991be222d6e76c92a53ffb76a3e307fff53641 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcce41ad98b8a7a3b7f1153b48813e936a06590bb716c184fbe012045063f815 +size 238995 diff --git a/prompts/Synthetic Object Replication/Slide5.png b/prompts/Synthetic Object Replication/Slide5.png new file mode 100644 index 0000000000000000000000000000000000000000..22565ba0e7b64c76e00fbb0519570f6ac35980b9 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a686955f717be1a514f68786268394eea86485b3aba333913d62c6548c8417d +size 236567 diff --git a/prompts/Synthetic Object Replication/Slide6.png b/prompts/Synthetic Object Replication/Slide6.png new file mode 100644 index 0000000000000000000000000000000000000000..6bcb3367c8db69f98a63a9995e5147458d66ca11 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f1b4e759376a1c26a837ff59a918ffd85a1689cb25064b39e8f7976aacb3ba7 +size 240757 diff --git a/prompts/Synthetic Object Replication/Slide7.png b/prompts/Synthetic Object Replication/Slide7.png new file mode 100644 index 0000000000000000000000000000000000000000..c5d7852cadabb41ae06e051d81d864916ed35abd --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1132a1a1f41a406069b628dcf2d81995ef48bc0a493d3e501c266c617d722499 +size 238630 diff --git a/prompts/Synthetic Object Replication/Slide8.png b/prompts/Synthetic Object Replication/Slide8.png new file mode 100644 index 0000000000000000000000000000000000000000..953743749cbd61eaa0d6371dadee67cccf4156a4 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de4e4c3b27e1d7691b697704c6506081e2cd5903929dca66deab53c4be08d886 +size 235972 diff --git a/prompts/Synthetic Object Replication/Slide9.png b/prompts/Synthetic Object Replication/Slide9.png new file mode 100644 index 0000000000000000000000000000000000000000..7430d331a7cb500dbe9852d0e5ad4995b3d132d8 --- /dev/null +++ b/prompts/Synthetic Object Replication/Slide9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2dbf3f768e24d4e46c04bad22f41798d717a15c566ad3d6b0d0813e23feb995f +size 233023 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..18ebbfe08f54d5ba0a5ca1c4f58220ce4da99c4d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +numpy +scipy +matplotlib +seaborn +jupyter +tqdm +pillow 
+--extra-index-url https://download.pytorch.org/whl/cu118 +transformers==4.34.1 +torch==2.0.1 +einops +absl-py +ml_collections +requests +mlxu==0.1.11 +pydantic +fastapi +uvicorn +gradio +fastapi +uvicorn +opencv-python-headless +scikit-video +scikit-image +natsort +accelerate diff --git a/torch_vqvae_model.py b/torch_vqvae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6053b36c5410da5f3a914e1a6644f840734311 --- /dev/null +++ b/torch_vqvae_model.py @@ -0,0 +1,257 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +from einops.layers.torch import Rearrange + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + +def swish(x): + return x*torch.sigmoid(x) + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, activation_fn="relu"): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False) + self.activation_fn = activation_fn + if activation_fn=="relu": + self.actn = nn.ReLU() + + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + if self.activation_fn=="relu": + x = self.actn(x) + elif self.activation_fn=="swish": + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + if self.activation_fn=="relu": + x = self.actn(x) + elif self.activation_fn=="swish": + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + +class Encoder(nn.Module): + def __init__(self, ): + super().__init__() + + self.filters = 128 + self.num_res_blocks = 2 + self.ch_mult = [1,1,2,2,4] + self.in_ch_mult = (1,)+tuple(self.ch_mult) + self.embedding_dim = 32 + self.conv_downsample = False + + self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False) + blocks = [] + for i in range(len(self.ch_mult)): + block_in_ch = self.filters * self.in_ch_mult[i] + block_out_ch = self.filters * self.ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) + block_in_ch = block_out_ch + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) + self.norm1 = normalize(block_in_ch) + self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + x = self.conv1(x) + for i in range(len(self.ch_mult)): + for j in range(self.num_res_blocks): + x = self.blocks[i*2+j](x) + + if i < len(self.ch_mult) -1: + x = torch.nn.functional.avg_pool2d(x, (2,2),(2,2)) + + x = self.blocks[-2](x) + x = self.blocks[-1](x) + + x = self.norm1(x) + x = swish(x) + x = self.conv2(x) + return x + +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size=8192, emb_dim=32, beta=None): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.embedding = 
nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + self.beta=0.0 + self.z_dim = emb_dim + + def forward(self, z): + # preprocess + + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + codebook = self.embedding.weight + with torch.no_grad(): + tokens = torch.cdist(flatten, codebook).argmin(dim=1) + quantized = F.embedding(tokens, + codebook).view(b, h, w, c).permute(0, 3, 1, 2) + + # compute loss + codebook_loss = F.mse_loss(quantized, z.detach()) + commitment_loss = F.mse_loss(quantized.detach(), z) + loss = codebook_loss + self.beta * commitment_loss + + # perplexity + counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) + # dist.all_reduce(counts) + p = counts / counts.sum() + perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) + + # postprocess + tokens = tokens.view(b, h, w) + quantized = z + (quantized - z).detach() + + # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c)) + + return quantized, tokens, loss, perplexity + + + def get_codebook_feat(self, indices, shape=None): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class Decoder(nn.Module): + def __init__(self,): + super().__init__() + self.filters = 128 + self.num_res_blocks = 2 + self.ch_mult = [1,1,2,2,4] + self.in_ch_mult = (1,)+tuple(self.ch_mult) + self.embedding_dim =32 + self.out_channels = 3 + self.in_channels = self.embedding_dim + self.conv_downsample = False + + self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1) + blocks = [] + block_in_ch = self.filters * self.ch_mult[-1] + block_out_ch = self.filters * self.ch_mult[-1] + #blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) + upsample_conv_layers = [] + for i in reversed(range(len(self.ch_mult))): + block_out_ch = self.filters * self.ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish")) + block_in_ch = block_out_ch + if i > 0: + upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch*4, kernel_size=3, stride=1, padding=1)) + + self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2) + self.norm1 = normalize(block_in_ch) + # self.act_fn + self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1) + self.blocks = nn.ModuleList(blocks) + self.up_convs = nn.ModuleList(upsample_conv_layers) + + def forward(self, x): + x = self.conv1(x) + x = self.blocks[0](x) + x = self.blocks[1](x) + for i in range(len(self.ch_mult)): + for j in range(self.num_res_blocks): + x = self.blocks[2+i*2+j](x) + if i < len(self.ch_mult)-1: + x = self.up_convs[i](x) + #print("pre: x.size()",x.size()) + x = x.permute(0,2,3,1) + x = self.upsample(x) + x = x.permute(0,3,1,2) + #print("post: x.size()", x.size()) + x = self.norm1(x) + x = swish(x) + x = self.conv6(x) + return x + + +class VQVAE(nn.Module): + def 
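The VectorQuantizer above maps each spatial latent to its nearest codebook entry with torch.cdist and then applies a straight-through estimator so gradients bypass the non-differentiable argmin. As a rough standalone sketch of just that step (assuming the default 8192-entry, 32-dimensional codebook; shapes are illustrative):

```python
# Sketch of the quantization step: nearest-codebook lookup plus a
# straight-through estimator, mirroring VectorQuantizer.forward above.
import torch
import torch.nn.functional as F

def quantize_sketch(z: torch.Tensor, codebook: torch.Tensor):
    """z: (B, C, H, W) latents; codebook: (K, C) embedding table."""
    b, c, h, w = z.shape
    flat = z.permute(0, 2, 3, 1).reshape(-1, c)          # one row per spatial position
    tokens = torch.cdist(flat, codebook).argmin(dim=1)   # index of the nearest code
    quant = F.embedding(tokens, codebook).view(b, h, w, c).permute(0, 3, 1, 2)
    quant = z + (quant - z).detach()                     # straight-through: gradient flows to z
    return quant, tokens.view(b, h, w)

z = torch.randn(2, 32, 16, 16)
codebook = torch.randn(8192, 32)
quant, tokens = quantize_sketch(z, codebook)
print(quant.shape, tokens.shape)  # torch.Size([2, 32, 16, 16]) torch.Size([2, 16, 16])
```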
__init__(self, ): + super().__init__() + self.encoder = Encoder() + self.quantizer = VectorQuantizer() + self.decoder = Decoder() + + def forward(self, x): + x = self.encoder(x) + quant,tokens, loss, perplexity = self.quantizer(x) + x = self.decoder(quant) + return x + + def tokenize(self, x): + batch_shape = x.shape[:-3] + x = x.reshape(-1, *x.shape[-3:]) + x = self.encoder(x) + quant,tokens, loss, perplexity = self.quantizer(x) + return tokens.reshape(*batch_shape, *tokens.shape[1:]) + + def decode(self, tokens): + tokens = einops.rearrange(tokens, 'b ... -> b (...)') + b = tokens.shape[0] + if tokens.shape[-1] == 256: + hw = 16 + elif tokens.shape[-1] == 224: + hw = 14 + else: + raise ValueError("Invalid tokens shape") + quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32)) + x = self.decoder(quant) + return x + + +class VAEDecoder(nn.Module): + def __init__(self, ): + super().__init__() + self.quantizer = VectorQuantizer() + self.decoder = Decoder() + + def forward(self, x): + quant = self.quantizer.get_codebook_feat(x,(1,14,14,32)) + x = self.decoder(quant) + return x + + +def get_tokenizer(): + checkpoint_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth" + ) + torch_state_dict = torch.load(checkpoint_path) + net = VQVAE() + net.load_state_dict(torch_state_dict) + return net + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b059a22e2fa8786ebba48dab800318b362507715 --- /dev/null +++ b/utils.py @@ -0,0 +1,296 @@ +import os +from multiprocessing import Pool +import numpy as np +import random +from PIL import Image +import re +import cv2 +import glob +from natsort import natsorted + + +class MultiProcessImageSaver(object): + + def __init__(self, n_workers=1): + self.pool = Pool(n_workers) + + def __call__(self, images, output_files, resizes=None): + if resizes is None: + resizes = [None for _ in range(len(images))] + return self.pool.imap( + self.save_image, + zip(images, output_files, resizes), + ) + + def close(self): + self.pool.close() + self.pool.join() + + @staticmethod + def save_image(args): + image, filename, resize = args + image = Image.fromarray(image) + if resize is not None: + image = image.resize(tuple(resize)) + image.save(filename) + + +def list_dir_with_full_path(path): + return [os.path.join(path, f) for f in os.listdir(path)] + + +def find_all_files_in_dir(path): + files = [] + for root, _, files in os.walk(path): + for file in files: + files.append(os.path.join(root, file)) + return files + + +def is_image(path): + return ( + path.endswith('.jpg') + or path.endswith('.png') + or path.endswith('.jpeg') + or path.endswith('.JPG') + or path.endswith('.PNG') + or path.endswith('.JPEG') + ) + + +def is_video(path): + return ( + path.endswith('.mp4') + or path.endswith('.avi') + or path.endswith('.MP4') + or path.endswith('.AVI') + or path.endswith('.webm') + or path.endswith('.WEBM') + or path.endswith('.mkv') + or path.endswith('.MVK') + ) + + +def random_square_crop(img, random_generator=None): + # If no random generator is provided, use numpy's default + if random_generator is None: + random_generator = np.random.default_rng() + + # Get the width and height of the image + width, height = img.size + + # Determine the shorter side + min_size = min(width, height) + + # Randomly determine the starting x and y coordinates for the crop + if width > height: + left = random_generator.integers(0, width - min_size) + upper = 0 + else: + left = 0 + upper = random_generator.integers(0, 
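A hypothetical end-to-end use of the VQVAE tokenizer defined in torch_vqvae_model.py above. It assumes the xh_ckpt.pth weights referenced by get_tokenizer() are present next to the module; the input shape is only an example:

```python
import torch
from torch_vqvae_model import get_tokenizer

tokenizer = get_tokenizer().eval()              # loads xh_ckpt.pth from the module directory
with torch.no_grad():
    frames = torch.rand(1, 4, 3, 256, 256)      # (batch, frames, C, H, W), values in [0, 1]
    tokens = tokenizer.tokenize(frames)         # (1, 4, 16, 16): one discrete code per latent cell
    recon = tokenizer.decode(tokens.reshape(-1, 16 * 16))  # (4, 3, 256, 256) reconstructions
print(tokens.shape, recon.shape)
```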
height - min_size) + + # Calculate the ending x and y coordinates for the crop + right = left + min_size + lower = upper + min_size + + # Crop the image + return img.crop((left, upper, right, lower)) + + +def read_image_to_tensor(path, center_crop=1.0): + pil_im = Image.open(path).convert('RGB') + if center_crop < 1.0: + width, height = pil_im.size + pil_im = pil_im.crop(( + int((1 - center_crop) * height / 2), int((1 + center_crop) * height / 2), + int((1 - center_crop) * width / 2), int((1 + center_crop) * width / 2), + )) + input_img = pil_im.resize((256, 256)) + input_img = np.array(input_img) / 255.0 + input_img = input_img.astype(np.float32) + return input_img + + +def match_mulitple_path(root_dir, regex): + videos = [] + for root, _, files in os.walk(root_dir): + for file in files: + videos.append(os.path.join(root, file)) + + videos = [v for v in videos if not v.split('/')[-1].startswith('.')] + + grouped_path = {} + for r in regex: + r = re.compile(r) + for v in videos: + matched = r.findall(v) + if len(matched) > 0: + groups = matched[0] + if groups not in grouped_path: + grouped_path[groups] = [] + grouped_path[groups].append(v) + + grouped_path = { + k: tuple(v) for k, v in grouped_path.items() + if len(v) == len(regex) + } + return list(grouped_path.values()) + + +def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True): + assert length >= n_frames + max_stride = min( + (length - 1) // (n_frames - 1), + max_stride + ) + stride = np.random.randint(1, max_stride + 1) + if random_start: + start = np.random.randint(0, length - (n_frames - 1) * stride) + else: + start = 0 + return np.arange(n_frames) * stride + start + + +def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0): + files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)] + files = natsorted([x for x in files if is_image(x)]) + + total_frames = len(files) + + if total_frames < n_frames: + return None + + max_stride = (total_frames - 1) // (n_frames - 1) + stride = min(max_stride, stride) + + if random_start: + start = np.random.randint(0, total_frames - (n_frames - 1) * stride) + else: + start = 0 + frame_indices = np.arange(n_frames) * stride + start + + frames = [] + for frame_index in sorted(frame_indices): + # Check if the frame_index is valid + frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop)) + if len(frames) < n_frames: + return None + frames = np.stack(frames, axis=0) + return frames + + +def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0): + + frames = [] + cap = cv2.VideoCapture(video_path) + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if total_frames < n_frames: + cap.release() + return None + + max_stride = (total_frames - 1) // (n_frames - 1) + stride = min(max_stride, stride) + + if random_start: + start = np.random.randint(0, total_frames - (n_frames - 1) * stride) + else: + start = 0 + frame_indices = np.arange(n_frames) * stride + start + + for frame_index in sorted(frame_indices): + # Check if the frame_index is valid + if 0 <= frame_index < total_frames: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) + ret, frame = cap.read() + if ret: + if center_crop < 1.0: + height, width, _ = frame.shape + frame = frame[ + int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), + int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), + : + ] + frame = cv2.resize(frame, (256, 256)) + + frames.append(frame) + + else: + 
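For the index arithmetic in randomly_subsample_frame_indices above, a small worked example with fixed numbers (the real helper draws stride and start at random within the same bounds):

```python
import numpy as np

length, n_frames, max_stride = 100, 8, 30
max_stride = min((length - 1) // (n_frames - 1), max_stride)  # 99 // 7 = 14, so max_stride becomes 14
stride = 10   # the helper samples this from [1, 14]
start = 5     # sampled from [0, length - (n_frames - 1) * stride) = [0, 30)
indices = np.arange(n_frames) * stride + start
print(indices)               # [ 5 15 25 35 45 55 65 75]
print(indices[-1] < length)  # True: clamping max_stride keeps the last index in range
```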
print(f"Frame index {frame_index} is out of bounds. Skipping...") + + cap.release() + if len(frames) < n_frames: + return None + frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 + + # From BGR to RGB + return np.stack( + [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 + ) + + +def read_all_frames_from_video(video_path, center_crop=1.0): + + frames = [] + cap = cv2.VideoCapture(video_path) + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + + for frame_index in range(total_frames): + # Check if the frame_index is valid + if 0 <= frame_index < total_frames: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) + ret, frame = cap.read() + if ret: + if center_crop < 1.0: + height, width, _ = frame.shape + frame = frame[ + int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2), + int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2), + : + ] + frames.append(cv2.resize(frame, (256, 256))) + else: + print(f"Frame index {frame_index} is out of bounds. Skipping...") + + cap.release() + if len(frames) == 0: + return None + frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 + # From BGR to RGB + return np.stack( + [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 + ) + + +def read_max_span_frames_from_video(video_path, n_frames): + frames = [] + cap = cv2.VideoCapture(video_path) + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames < n_frames: + cap.release() + return None + stride = (total_frames - 1) // (n_frames - 1) + frame_indices = np.arange(n_frames) * stride + + frames = [] + for frame_index in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) + ret, frame = cap.read() + if ret: + frames.append(cv2.resize(frame, (256, 256))) + + cap.release() + if len(frames) < n_frames: + return None + + frames = np.stack(frames, axis=0).astype(np.float32) / 255.0 + # From BGR to RGB + return np.stack( + [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1 + ) + diff --git a/vqvae/.DS_Store b/vqvae/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..38734ca2de71d90578b12a191d5ff30a57f26d5c Binary files /dev/null and b/vqvae/.DS_Store differ diff --git a/vqvae/__init__.py b/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b32aee9a68b1192ae0be7214ca92f35defd717 --- /dev/null +++ b/vqvae/__init__.py @@ -0,0 +1,25 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
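A usage sketch for the video helpers in utils.py above; the file name is a placeholder, and the second half spells out the BGR-to-RGB shuffle that ends each reader:

```python
import numpy as np
from utils import read_max_span_frames_from_video

frames = read_max_span_frames_from_video("clip.mp4", n_frames=16)  # placeholder path
if frames is not None:
    print(frames.shape, frames.dtype)  # (16, 256, 256, 3) float32, values in [0, 1], RGB order

# The final np.stack over channels 2, 1, 0 is simply a channel reversal:
bgr = np.random.rand(2, 4, 4, 3).astype(np.float32)
rgb_stack = np.stack([bgr[..., 2], bgr[..., 1], bgr[..., 0]], axis=-1)  # as written in utils.py
rgb_slice = bgr[..., ::-1]                                              # equivalent one-liner
assert np.allclose(rgb_stack, rgb_slice)
```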
+ +__version__ = "0.0.1" + +# from .modeling_ema import EMAModel +# from .modeling_maskgit_vqgan import MaskGitVQGAN +# from .modeling_movq import MOVQ +# from .modeling_paella_vq import PaellaVQModel +# from .modeling_utils import VQGANModel +# from .modeling_transformer import MaskGitTransformer, MaskGiTUViT +# from .pipeline_muse import PipelineMuse, PipelineMuseInpainting +# from .sampling import get_mask_chedule diff --git a/vqvae/logging.py b/vqvae/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..65814a82380e47e54434c4be97026141772f7298 --- /dev/null +++ b/vqvae/logging.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2023 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("muse_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. 
+ library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom muse module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 muse' root logger as an int. + + Returns: + `int`: The logging level. + + + + 🤗 muse has following logging levels: + + - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL` + - 40: `muse.logging.ERROR` + - 30: `muse.logging.WARNING` or `muse.logging.WARN` + - 20: `muse.logging.INFO` + - 10: `muse.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 muse' root logger. + + Args: + verbosity (`int`): + Logging level, e.g., one of: + + - `muse.logging.CRITICAL` or `muse.logging.FATAL` + - `muse.logging.ERROR` + - `muse.logging.WARNING` or `muse.logging.WARN` + - `muse.logging.INFO` + - `muse.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. 
+ """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for HuggingFace muse' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/vqvae/modeling_utils.py b/vqvae/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbdb68a7c2f154c670fe40950c035fad06e4691 --- /dev/null +++ b/vqvae/modeling_utils.py @@ -0,0 +1,1171 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
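The vqvae.logging module above is a thin wrapper around the standard logging package plus a tqdm toggle. A short usage sketch (the logger name and message are arbitrary):

```python
from vqvae import logging

logging.set_verbosity_info()              # or set the muse_VERBOSITY environment variable beforehand
logger = logging.get_logger("vqvae.demo")
logger.info("tokenizer ready")            # emitted through the library's default stderr handler
logging.disable_progress_bar()            # swaps tqdm for the no-op EmptyTqdm wrapper
```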
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import inspect +import json +import os +from collections import OrderedDict +from functools import partial +from pathlib import PosixPath +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import accelerate +import numpy as np +import torch +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from requests import HTTPError +from torch import Tensor, device + +from . import __version__, logging + +logger = logging.get_logger(__name__) + + +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +default_cache_path = os.path.join(hf_cache_home, "muse") + + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" +SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" +MUSE_CACHE = default_cache_path +MUSE_DYNAMIC_MODULE_NAME = "myse_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) + + +_LOW_CPU_MEM_USAGE_DEFAULT = True + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a checkpoint file, returning properly formatted errors if they arise. + """ + try: + if os.path.basename(checkpoint_file) == WEIGHTS_NAME: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." 
+ ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." 
+ ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) + + +class ModelMixin(torch.nn.Module): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models. + + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~models.ModelMixin.save_pretrained`]. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + + def __init__(self): + super().__init__() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. 
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import UNet2DConditionModel + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 + ... ) + >>> model = model.to("cuda") + >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Callable = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~models.ModelMixin.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + The state dictionary to save. If `None`, the model's state dictionary will be saved. 
+ """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + weights_name = WEIGHTS_NAME + + # Save the model + save_function(state_dict, os.path.join(save_directory, weights_name)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + """ + cache_dir = kwargs.pop("cache_dir", MUSE_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) # TODO + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + # This variable will flag if we're loading a sharded checkpoint. 
In this case the archive file is just the + # Load model + + model_file = None + + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + if accepts_dtype: + set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype) + else: + set_module_tensor_to_device(model, param_name, param_device, value=param) + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by deafult the device_map is None and the weights are loaded on the CPU + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + config, unused_kwargs = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = [k for k in state_dict.keys()] + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." 
+ ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. 
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + +""" ConfigMixin base class and utilities.""" + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + +class ConfigMixin: + r""" + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). + """ + config_name = None + ignore_for_config = [] + has_compatibles = False + + _deprecated_kwargs = [] + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. 
+ kwargs.pop("kwargs", None) + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + logger.debug(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = FrozenDict(internal_dict) + + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~ConfigMixin.from_config`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) + + self.to_json_file(output_config_file) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, **kwargs): + r""" + Instantiate a Python class from a config dictionary + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class will be instantiated. Make sure to only load + configuration files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the Python class. + `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments of `config`. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. 
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + # Return model and optionally state and/or unused_kwargs + model = cls(**config) + return model + + @classmethod + def load_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Instantiate a Python class from a config dictionary + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). 
+ + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + """ + cache_dir = kwargs.pop("cache_dir", MUSE_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "config"} + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + if cls.config_name is None: + raise ValueError( + "`self.config_name` is not defined. Note that one should not load a config from " + "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" + ) + + if os.path.isfile(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + elif os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): + # Load from a PyTorch checkpoint + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + ): + config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) + else: + raise EnvironmentError( + f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + config_file = hf_hub_download( + pretrained_model_name_or_path, + filename=cls.config_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" + " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" + " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" + " login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" + " this model name. Check the model page at" + f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." + ) + except HTTPError as err: + raise EnvironmentError( + "There was a specific connection error when trying to load" + f" {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" + " run the library in offline mode at" + " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." 
+ ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + + if return_unused_kwargs: + return config_dict, kwargs + + return config_dict + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_version"] = __version__ + + def to_json_saveable(value): + if isinstance(value, np.ndarray): + value = value.tolist() + elif isinstance(value, PosixPath): + value = str(value) + return value + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." 
+ ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init diff --git a/vqvae_muse.py b/vqvae_muse.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d57f7e54c46b71c75bff0124f02d6df146d7cb --- /dev/null +++ b/vqvae_muse.py @@ -0,0 +1,594 @@ +# coding=utf-8 +# Copyright 2023 The Taming Transformers Authors and The HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import partial +from typing import Tuple +import os +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from vqvae.modeling_utils import ConfigMixin, ModelMixin, register_to_config + + +class Upsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + if self.with_conv: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Downsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool): + super().__init__() + + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + if self.with_conv: + pad = (0, 1, 0, 1) # pad height and width dim + hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0) + hidden_states = self.conv(hidden_states) + else: + hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2) + return hidden_states + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + use_conv_shortcut: bool = False, + dropout_prob: float = 0.0, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels + self.use_conv_shortcut = use_conv_shortcut + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d( + self.in_channels, + self.out_channels_, + kernel_size=3, + stride=1, + padding=1, + ) + + self.norm2 = 
nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True) + self.dropout = nn.Dropout(dropout_prob) + self.conv2 = nn.Conv2d( + self.out_channels_, + self.out_channels_, + kernel_size=3, + stride=(1, 1), + padding=1, + ) + + if self.in_channels != self.out_channels_: + if use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + self.in_channels, + self.out_channels_, + kernel_size=3, + stride=1, + padding=1, + ) + else: + self.nin_shortcut = nn.Conv2d( + self.in_channels, + self.out_channels_, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels_: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return hidden_states + residual + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + + self.in_channels = in_channels + conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0) + + self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True) + self.q, self.k, self.v = conv(), conv(), conv() + self.proj_out = conv() + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm(hidden_states) + + query = self.q(hidden_states) + key = self.k(hidden_states) + value = self.v(hidden_states) + + # compute attentions + batch, channels, height, width = query.shape + query = query.reshape((batch, channels, height * width)) + query = query.permute(0, 2, 1) # (b, hw, c) + key = key.reshape((batch, channels, height * width)) + + attn_weights = torch.bmm(query, key) # b,hw,hw + attn_weights = attn_weights * (int(channels) ** -0.5) + attn_weights = nn.functional.softmax(attn_weights, dim=2) + + # attend to values + value = value.reshape((batch, channels, height * width)) + attn_weights = attn_weights.permute(0, 2, 1) + hidden_states = torch.bmm(value, attn_weights) + hidden_states = hidden_states.reshape((batch, channels, height, width)) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class UpsamplingBlock(nn.Module): + def __init__(self, config, curr_res: int, block_idx: int): + super().__init__() + + self.config = config + self.block_idx = block_idx + self.curr_res = curr_res + + if self.block_idx == self.config.num_resolutions - 1: + block_in = self.config.hidden_channels * self.config.channel_mult[-1] + else: + block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1] + + block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] + + res_blocks = [] + attn_blocks = [] + for _ in range(self.config.num_res_blocks + 1): + res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) + block_in = block_out + if self.curr_res in self.config.attn_resolutions: + attn_blocks.append(AttnBlock(block_in)) + + self.block = nn.ModuleList(res_blocks) + self.attn = nn.ModuleList(attn_blocks) + + self.upsample = None + if self.block_idx != 0: + self.upsample = Upsample(block_in, self.config.resample_with_conv) + + def forward(self, 
hidden_states): + for i, res_block in enumerate(self.block): + hidden_states = res_block(hidden_states) + if len(self.attn) > 1: + hidden_states = self.attn[i](hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class DownsamplingBlock(nn.Module): + def __init__(self, config, curr_res: int, block_idx: int): + super().__init__() + + self.config = config + self.curr_res = curr_res + self.block_idx = block_idx + + in_channel_mult = (1,) + tuple(self.config.channel_mult) + block_in = self.config.hidden_channels * in_channel_mult[self.block_idx] + block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] + + res_blocks = nn.ModuleList() + attn_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout)) + block_in = block_out + if self.curr_res in self.config.attn_resolutions: + attn_blocks.append(AttnBlock(block_in)) + + self.block = res_blocks + self.attn = attn_blocks + + self.downsample = None + if self.block_idx != self.config.num_resolutions - 1: + self.downsample = Downsample(block_in, self.config.resample_with_conv) + + def forward(self, hidden_states): + for i, res_block in enumerate(self.block): + hidden_states = res_block(hidden_states) + if len(self.attn) > 1: + hidden_states = self.attn[i](hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class MidBlock(nn.Module): + def __init__(self, config, in_channels: int, no_attn: False, dropout: float): + super().__init__() + + self.config = config + self.in_channels = in_channels + self.no_attn = no_attn + self.dropout = dropout + + self.block_1 = ResnetBlock( + self.in_channels, + self.in_channels, + dropout_prob=self.dropout, + ) + if not no_attn: + self.attn_1 = AttnBlock(self.in_channels) + self.block_2 = ResnetBlock( + self.in_channels, + self.in_channels, + dropout_prob=self.dropout, + ) + + def forward(self, hidden_states): + hidden_states = self.block_1(hidden_states) + if not self.no_attn: + hidden_states = self.attn_1(hidden_states) + hidden_states = self.block_2(hidden_states) + return hidden_states + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # downsampling + self.conv_in = nn.Conv2d( + self.config.num_channels, + self.config.hidden_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + curr_res = self.config.resolution + downsample_blocks = [] + for i_level in range(self.config.num_resolutions): + downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level)) + + if i_level != self.config.num_resolutions - 1: + curr_res = curr_res // 2 + self.down = nn.ModuleList(downsample_blocks) + + # middle + mid_channels = self.config.hidden_channels * self.config.channel_mult[-1] + self.mid = MidBlock(config, mid_channels, self.config.no_attn_mid_block, self.config.dropout) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d( + mid_channels, + self.config.z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values): + # downsampling + hidden_states = self.conv_in(pixel_values) + for block in self.down: + hidden_states = block(hidden_states) + + # middle + hidden_states = self.mid(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + 
hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class Decoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # compute in_channel_mult, block_in and curr_res at lowest res + block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1] + curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1) + self.z_shape = (1, self.config.z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + self.config.z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + # middle + self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout) + + # upsampling + upsample_blocks = [] + for i_level in reversed(range(self.config.num_resolutions)): + upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level)) + if i_level != 0: + curr_res = curr_res * 2 + self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order + + # end + block_out = self.config.hidden_channels * self.config.channel_mult[0] + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d( + block_out, + self.config.num_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states): + # z to block_in + hidden_states = self.conv_in(hidden_states) + + # middle + hidden_states = self.mid(hidden_states) + + # upsampling + for block in reversed(self.up): + hidden_states = block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + Discretization bottleneck part of the VQ-VAE. + """ + + def __init__(self, num_embeddings, embedding_dim, commitment_cost): + r""" + Args: + num_embeddings: number of vectors in the quantized space. + embedding_dim: dimensionality of the tensors in the quantized space. + Inputs to the modules must be in this format as well. + commitment_cost: scalar which controls the weighting of the loss terms + (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). + """ + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.commitment_cost = commitment_cost + + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) + + def forward(self, hidden_states, return_loss=False): + """ + Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the + closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. 
flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + + distances = self.compute_distances(hidden_states) + min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape) + + # reshape to (batch, num_tokens) + min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) + + # compute loss for embedding + loss = None + if return_loss: + loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean( + (z_q - hidden_states.detach()) ** 2 + ) + # preserve gradients + z_q = hidden_states + (z_q - hidden_states).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, min_encoding_indices, loss + + def compute_distances(self, hidden_states): + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim)) + emb_weights = self.embedding.weight.t() + + inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True) + codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True) + distances = torch.addmm( + inputs_norm_sq + codebook_t_norm_sq, + hidden_states_flattended, + emb_weights, + alpha=-2.0, + ) + return distances + + def get_codebook_entry(self, indices): + # indices are expected to be of shape (batch, num_tokens) + # get quantized latent vectors + batch, num_tokens = indices.shape + z_q = self.embedding(indices) + z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2) + return z_q + + # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 + def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel) + distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings) + + soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings) + if stochastic: + code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) + else: + code = distances.argmin(dim=-1) # (batch * height * width) + + code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) + batch, num_tokens = code.shape + soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings) + return soft_code, code + + def get_code(self, hidden_states): + # reshape z -> (batch, height, width, channel) + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + distances = self.compute_distances(hidden_states) + indices = torch.argmin(distances, axis=1).unsqueeze(1) + indices = indices.reshape(hidden_states.shape[0], -1) + return indices + + +class VQGANModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + resolution: int = 256, + num_channels: int = 3, + hidden_channels: int = 128, + channel_mult: Tuple = (1, 1, 2, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: int = (16,), + no_attn_mid_block: bool = False, + z_channels: int = 256, + num_embeddings: int = 1024, + quantized_embed_dim: int = 256, + 
dropout: float = 0.0, + resample_with_conv: bool = True, + commitment_cost: float = 0.25, + ): + super().__init__() + + self.config.num_resolutions = len(channel_mult) + self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1) + self.config.latent_size = resolution // self.config.reduction_factor + + self.encoder = Encoder(self.config) + self.decoder = Decoder(self.config) + self.quantize = VectorQuantizer( + self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost + ) + self.quant_conv = nn.Conv2d( + self.config.z_channels, + self.config.quantized_embed_dim, + kernel_size=1, + ) + self.post_quant_conv = nn.Conv2d( + self.config.quantized_embed_dim, + self.config.z_channels, + kernel_size=1, + ) + + def encode(self, pixel_values, return_loss=False): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss) + output = (quantized_states, codebook_indices) + if return_loss: + output = output + (codebook_loss,) + return output + + def decode(self, quantized_states): + hidden_states = self.post_quant_conv(quantized_states) + reconstructed_pixel_values = self.decoder(hidden_states) + return reconstructed_pixel_values + + def decode_code(self, codebook_indices): + quantized_states = self.quantize.get_codebook_entry(codebook_indices) + reconstructed_pixel_values = self.decode(quantized_states) + return reconstructed_pixel_values + + def get_code(self, pixel_values): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + codebook_indices = self.quantize.get_code(hidden_states) + return codebook_indices + + def forward(self, pixel_values, return_loss=False): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss) + reconstructed_pixel_values = self.decode(quantized_states) + outputs = (reconstructed_pixel_values, quantized_states, codebook_indices) + if return_loss: + outputs = outputs + (codebook_loss,) + return outputs + + + +def get_tokenizer_muse(): + + ckpts_path = "Emma02/vqvae_ckpts" + net = VQGANModel.from_pretrained(ckpts_path) + + return net
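
The logging helpers near the top of this diff (explicit formatting, `warning_advice`, and the tqdm wrapper) are only defined, never exercised. Below is a minimal usage sketch; it assumes the module lives at `vqvae/logging.py` (consistent with the `from . import logging` in `vqvae/modeling_utils.py`) and that `get_logger` is defined in the earlier, unshown part of that module, since `modeling_utils.py` calls it.

```python
from vqvae import logging  # assumed path: vqvae/logging.py, as imported by modeling_utils.py

# Switch every handler on the library root logger to the explicit format:
# [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
logging.enable_explicit_format()

logger = logging.get_logger(__name__)
# warning_advice() is patched onto logging.Logger in this module; it is silenced
# whenever the muse_NO_ADVISORY_WARNINGS environment variable is set.
logger.warning_advice("advisory warning, suppressible via muse_NO_ADVISORY_WARNINGS")

# The tqdm wrapper honours the module-level switch: after disable_progress_bar(),
# logging.tqdm(...) returns the no-op EmptyTqdm instead of a real progress bar.
logging.disable_progress_bar()
for _ in logging.tqdm(range(3)):
    pass
logging.enable_progress_bar()
print(logging.is_progress_bar_enabled())
```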
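
`vqvae/modeling_utils.py` pairs `ModelMixin` with `ConfigMixin` exactly the way `VQGANModel` does at the end of `vqvae_muse.py`. The sketch below is not part of the repository; the `TinyModel` class and the `tiny_model` directory are hypothetical. It shows the intended pattern: `@register_to_config` records the `__init__` arguments into `self.config`, `save_pretrained()` writes `config.json` plus `pytorch_model.bin`, and `from_pretrained()` restores both.

```python
import torch
from torch import nn

from vqvae.modeling_utils import ConfigMixin, ModelMixin, register_to_config


class TinyModel(ModelMixin, ConfigMixin):
    """Hypothetical example model; mirrors the VQGANModel(ModelMixin, ConfigMixin) pattern."""

    @register_to_config
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


model = TinyModel(hidden_size=16)
model.save_pretrained("tiny_model")              # writes config.json + pytorch_model.bin
restored = TinyModel.from_pretrained("tiny_model")
print(restored.config.hidden_size, restored.dtype)  # 16, torch.float32
```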
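
`VectorQuantizer.compute_distances` evaluates the squared distance from every encoder vector to every codebook entry without materializing the full difference tensor, using the identity ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e and a single `torch.addmm` call. A small self-contained check of that identity (toy shapes chosen for illustration, not taken from the model):

```python
import torch

z = torch.randn(4, 8)          # flattened encoder outputs, (batch*height*width, embedding_dim)
codebook = torch.randn(16, 8)  # embedding table, (num_embeddings, embedding_dim)

z_sq = z.pow(2).sum(dim=1, keepdim=True)              # (4, 1)
e_sq = codebook.t().pow(2).sum(dim=0, keepdim=True)   # (1, 16)
# out = (z_sq + e_sq) - 2 * z @ codebook.T, same formulation as compute_distances
distances = torch.addmm(z_sq + e_sq, z, codebook.t(), alpha=-2.0)

# Reference computed the direct way.
reference = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
print(torch.allclose(distances, reference, atol=1e-5))  # True
```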
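
Finally, `get_tokenizer_muse()` is the entry point for loading the pretrained VQGAN from `Emma02/vqvae_ckpts`. A hedged end-to-end sketch follows: the image path, the 256x256 resize, and the [0, 1] pixel scaling are illustrative assumptions rather than requirements stated in this file. With the default config (`channel_mult` of length 5), the reduction factor is 16, so a 256x256 image maps to a 16x16 grid of codebook indices.

```python
import numpy as np
import torch
from PIL import Image

from vqvae_muse import get_tokenizer_muse

tokenizer = get_tokenizer_muse()  # loads the "Emma02/vqvae_ckpts" checkpoint defined above
tokenizer.eval()

# Hypothetical input image; 256x256 matches the default `resolution` of VQGANModel,
# and scaling to [0, 1] is an assumption about the expected input range.
image = Image.open("example.png").convert("RGB").resize((256, 256))
pixel_values = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0)
pixel_values = pixel_values.permute(2, 0, 1).unsqueeze(0)  # (1, 3, 256, 256)

with torch.no_grad():
    codes = tokenizer.get_code(pixel_values)       # (1, 256): flattened 16x16 grid of indices
    reconstruction = tokenizer.decode_code(codes)  # (1, 3, 256, 256)

print(codes.shape, reconstruction.shape)
```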