diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ce75384ce0ff24309df0c1d799b647e0fa631e3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +pretrained_models diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0df63cf507269d912eac0e915fc7a592adb520a4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,62 @@ +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 + +# These are all pre-defined +ENV DEBIAN_FRONTEND=noninteractive \ + TZ=Europe/Paris +# Install some basic utilities +RUN rm -f /etc/apt/sources.list.d/*.list && \ + apt-get update && apt-get install -y --no-install-recommends \ + sudo \ + git \ + curl \ + wget \ + ffmpeg libsm6 libxext6 \ + && rm -rf /var/lib/apt/lists/* +# Create a working directory +WORKDIR /app +# Create a non-root user and switch to it +RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ + && chown -R user:user /app \ + && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user +USER user +# All users can use /home/user as their home directory +ENV HOME=/home/user \ + CONDA_AUTO_UPDATE_CONDA=false +ENV PATH=$HOME/miniconda/bin:$PATH +RUN mkdir $HOME/.cache $HOME/.config \ + && chmod -R 777 $HOME \ + && curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py310_24.5.0-0-Linux-x86_64.sh \ + && chmod +x ~/miniconda.sh \ + && ~/miniconda.sh -b -p ~/miniconda \ + && rm ~/miniconda.sh \ + && conda clean -ya + +# From here are my stuff + +# Download models +RUN pip install --no-cache-dir gdown && \ + mkdir -p ./pretrained_models/GLIP/checkpoints && \ + mkdir -p ./pretrained_models/GLIP/configs && \ + mkdir -p ./pretrained_models/xvlm && \ + wget -nc -q -P ./pretrained_models/GLIP/checkpoints https://huggingface.co/GLIPModel/GLIP/resolve/main/glip_large_model.pth && \ + wget -nc -q -P ./pretrained_models/GLIP/configs https://raw.githubusercontent.com/microsoft/GLIP/main/configs/pretrain/glip_Swin_L.yaml && \ + gdown "https://drive.google.com/u/0/uc?id=1bv6_pZOsXW53EhlwU0ZgSk03uzFI61pN" -O ./pretrained_models/xvlm/retrieval_mscoco_checkpoint_9.pth + +# Python packages +RUN --mount=target=requirements.txt,source=requirements.txt \ + pip install --no-cache-dir torch torchvision && \ + pip install --no-cache-dir git+https://github.com/openai/CLIP.git && \ + pip install --no-cache-dir -r requirements.txt + +RUN python -c "from transformers import AutoModel; _ = AutoModel.from_pretrained('codellama/CodeLlama-7b-Python-hf')" +RUN python -c "from transformers import AutoModel; _ = AutoModel.from_pretrained('VDebugger/VDebugger-critic-generalist-7B')" +RUN python -c "from transformers import AutoModel; _ = AutoModel.from_pretrained('VDebugger/VDebugger-refiner-generalist-7B')" + +# Download GLIP dependencies, but unfortunately don't install yet... +RUN git clone https://github.com/sachit-menon/GLIP + +# Run gradio +COPY --link --chown=1000 ./ /app +EXPOSE 7860 +ENV GRADIO_SERVER_NAME="0.0.0.0" +CMD ["bash", "app.sh"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f87fcb2758d84e15d8fd265b197654bc68ea3dca --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +--- +title: VDebugger generalist for VQA +emoji: 💬 +colorFrom: yellow +colorTo: purple +sdk: docker +sdk_version: 4.36.1 +app_file: app.py +pinned: false +license: apache-2.0 +models: +- codellama/CodeLlama-7b-Python-hf +- VDebugger/VDebugger-critic-generalist-7B +- VDebugger/VDebugger-refiner-generalist-7B +--- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..388edbd9fa8e65c4e8de135bfd3035bf578200d0 --- /dev/null +++ b/app.py @@ -0,0 +1,333 @@ +import inspect +import json +import os +import random +from typing import Literal, cast + +import gradio as gr +import torch +from PIL import Image +from gradio.data_classes import InterfaceTypes +from gradio.flagging import CSVLogger +from torchvision import transforms +from transformers import AutoTokenizer, LlamaForCausalLM + +from trace_exec import run_program_with_trace, CompileTimeError +from vision_processes import load_models + +print("-" * 10, "Loading models...") +load_models() + +with open('joint.prompt') as f: + prompt_template = f.read().strip() + +INPUT_TYPE = 'image' +OUTPUT_TYPE = 'str' +SIGNATURE = f'def execute_command({INPUT_TYPE}) -> {OUTPUT_TYPE}:' + + +def generate(model, input_text): + torch.cuda.empty_cache() + print("-" * 10, "Before loading LLM:") + print(torch.cuda.memory_summary()) + + dtype = os.environ.get("CODELLAMA_DTYPE") + assert dtype in ['bfloat16', '8bit', '4bit', ] + tokenizer = AutoTokenizer.from_pretrained(model) + model = LlamaForCausalLM.from_pretrained( + model, + device_map="auto", + load_in_8bit=dtype == "8bit", + load_in_4bit=dtype == "4bit", + torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None, + ) + print("-" * 10, "LLM loaded:") + print(model) + print(torch.cuda.memory_summary()) + + input_ids = tokenizer(input_text, return_tensors="pt").input_ids + generated_ids = model.generate( + input_ids.to('cuda'), max_new_tokens=256, stop_strings=["\n\n"], do_sample=False, tokenizer=tokenizer + ) + generated_ids = generated_ids[0][input_ids.shape[1]:] + text = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + del model + torch.cuda.empty_cache() + print("-" * 10, "After loading LLM:") + print(torch.cuda.memory_summary()) + + return text + + +def to_custom_trace(result, error, traced): + if traced is None: + assert isinstance(error, CompileTimeError) + traced = 'Compile Error' + return "-> {}\n\n--- Trace\n\n{}".format(result, traced) + + +def answer_from_trace(x): + assert x.startswith("->") + return x[2:].splitlines()[0].strip() + + +def debug(image, question, code, traced_info): + # critic + prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is" + print("--- For debug: critic prompt is ---") + print(prompt) + print("---\n") + critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt) + incorrect = critic_out.strip().startswith('wrong') + critic_out = "# Program is" + critic_out + + if not incorrect: + yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info) + return + else: + yield code, traced_info, critic_out, "RUNNING IN PROGRESS...", "", "" + + # refiner + critic_code = ('def execute_command' + critic_out.split('def execute_command')[1]).strip() + if '# Program is' in code: + critic_code = critic_code.split("# Program is")[0].strip() # errr, an awkward fix + prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction" + print("--- For debug: refiner prompt is ---") + print(prompt) + print("---\n") + refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip() + yield code, traced_info, critic_out, refiner_out, "RUNNING IN PROGRESS...", "" + + # execute (again) + result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE) + traced_info_2 = to_custom_trace(result, error, traced) + + yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2) + + +def predict(image, question): + if image is None: + gr.Warning("Please provide an image", duration=5) + return + image = transforms.Compose([transforms.ToTensor()])(image) + + question = question.strip() + if question == "": + gr.Warning("Please provide a question", duration=5) + return + + # codellama + prompt = prompt_template.replace("INSERT_QUERY_HERE", f"Given an image: {question}\n{SIGNATURE}") + code = generate("codellama/CodeLlama-7b-Python-hf", prompt) + code = (SIGNATURE + code).strip() + yield code, "RUNNING IN PROGRESS...", "", "", "", "" + + # execute + result, error, traced = run_program_with_trace(code, image, INPUT_TYPE, OUTPUT_TYPE) + traced_info = to_custom_trace(result, error, traced) + yield code, traced_info, "RUNNING IN PROGRESS...", "", "", "" + + for tup in debug(image, question, code, traced_info): + yield tup + return + + +def re_debug(image, question, code, traced_info): + if code is None or code == "" or traced_info is None or traced_info == "": + gr.Warning("No prior debugging round", duration=5) + return + + yield code, traced_info, "RUNNING IN PROGRESS...", "", "", "" + for tup in debug(image, question, code, traced_info): + yield tup + return + + +DESCRIPTION = """# VDebugger + +| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) | + +**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the visual programs, the outputs from both the critic and the refiner, as well as the final result. + +**Warning:** Reduced performance and accuracy may be observed. Due to resource limitation of huggingface spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs. For full capacity, please use the original code.""" + + +class MyInterface(gr.Interface): + def __init__(self): + super(gr.Interface, self).__init__( + title=None, + theme=None, + analytics_enabled=None, + mode="tabbed_interface", + css=None, + js=None, + head=None, + ) + self.interface_type = InterfaceTypes.STANDARD + self.description = DESCRIPTION + self.cache_examples = None + self.examples_per_page = 5 + self.example_labels = None + self.batch = False + self.live = False + self.api_name = "predict" + self.max_batch_size = 4 + self.concurrency_limit = 'default' + self.show_progress = "full" + self.allow_flagging = 'auto' + self.flagging_options = [("Flag", ""), ] + self.flagging_callback = CSVLogger() + self.flagging_dir = 'flagged' + + # Load examples + with open('examples/questions.json') as f: + example_questions = json.load(f) + self.examples = [] + for question in example_questions: + self.examples.append([ + Image.open('examples/{}.jpg'.format(question['imageId'])), question['question'], + ]) + + def load_random_example(): + image, question = random.choice(self.examples) + return image, question, "", "", "", "", "", "" + + # Render the Gradio UI + with self: + self.render_title_description() + + with gr.Row(): + image = gr.Image(label="Image", type="pil", width="30%", scale=1) + question = gr.Textbox(label="Question", scale=2) + + with gr.Row(): + _clear_btn = gr.ClearButton(value="Clear", variant="secondary") + _random_eg_btn = gr.Button("Random Example Input") + _submit_btn = gr.Button("Submit", variant="primary") + if inspect.isgeneratorfunction(predict) or inspect.isasyncgenfunction(predict): + _stop1_btn = gr.Button("Stop", variant="stop", visible=False) + _redebug_btn = gr.Button("Debug for Another Round", variant="primary") + if inspect.isgeneratorfunction(re_debug) or inspect.isasyncgenfunction(re_debug): + _stop2_btn = gr.Button("Stop", variant="stop", visible=False) + + with gr.Row(): + o1 = gr.Textbox(label="No debugging: program") + o2 = gr.Textbox(label="No debugging: execution") + + with gr.Row(): + o3 = gr.Textbox(label="VDebugger: critic") + o4 = gr.Textbox(label="VDebugger: refiner") + + with gr.Row(): + o5 = gr.Textbox(label="VDebugger: execution") + o6 = gr.Textbox(label="VDebugger: final answer") + + question.submit(fn=predict, inputs=[image, question], outputs=[o1, o2, o3, o4, o5, o6]) + _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6]) + + async def cleanup(): + return [gr.Button(visible=True), gr.Button(visible=False)] + + # Setup redebug event + triggers = [_redebug_btn.click, ] + extra_output = [_redebug_btn, _stop2_btn] + predict_event = gr.on( + triggers, + gr.utils.async_lambda( + lambda: ( + gr.Button(visible=False), + gr.Button(visible=True), + ) + ), + inputs=None, + outputs=[_redebug_btn, _stop2_btn], + queue=False, + show_api=False, + ).then( + re_debug, + [image, question, o4, o5], + [o1, o2, o3, o4, o5, o6], + api_name=self.api_name, + scroll_to_output=False, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + concurrency_limit=self.concurrency_limit, + show_progress=cast( + Literal["full", "minimal", "hidden"], self.show_progress + ), + ) + redebug_event = predict_event.then( + cleanup, + inputs=None, + outputs=extra_output, # type: ignore + queue=False, + show_api=False, + ) + _stop2_btn.click( + cleanup, + inputs=None, + outputs=[_redebug_btn, _stop2_btn], + cancels=predict_event, + queue=False, + show_api=False, + ) + + # Setup submit event + triggers = [_submit_btn.click, question.submit, ] + extra_output = [_submit_btn, _stop1_btn] + predict_event = gr.on( + triggers, + gr.utils.async_lambda( + lambda: ( + gr.Button(visible=False), + gr.Button(visible=True), + ) + ), + inputs=None, + outputs=[_submit_btn, _stop1_btn], + queue=False, + show_api=False, + ).then( + predict, + [image, question], + [o1, o2, o3, o4, o5, o6], + api_name=self.api_name, + scroll_to_output=False, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + concurrency_limit=self.concurrency_limit, + show_progress=cast( + Literal["full", "minimal", "hidden"], self.show_progress + ), + ) + submit_event = predict_event.then( + cleanup, + inputs=None, + outputs=extra_output, # type: ignore + queue=False, + show_api=False, + ) + _stop1_btn.click( + cleanup, + inputs=None, + outputs=[_submit_btn, _stop1_btn], + cancels=predict_event, + queue=False, + show_api=False, + ) + + # Finally borrow Interface stuff + self.input_components = [image, question] + self.output_components = [o1, o2, o3, o4, o5, o6] + self.fn = predict + self.attach_clear_events(_clear_btn, None) + self.render_examples() + + +if __name__ == "__main__": + MyInterface().launch(share=os.environ.get("SHARE", '') != "") diff --git a/app.sh b/app.sh new file mode 100644 index 0000000000000000000000000000000000000000..ca254327c6d4a0081c205ddc5bd16b358ca7f9d5 --- /dev/null +++ b/app.sh @@ -0,0 +1,5 @@ +cd GLIP +python setup.py clean --all build develop --user +cd ../ +python -c "import maskrcnn_benchmark" # check successfully installed +python app.py diff --git a/examples/n111074.jpg b/examples/n111074.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f85bd1daec96377139ef9644654be899a99457ef Binary files /dev/null and b/examples/n111074.jpg differ diff --git a/examples/n113863.jpg b/examples/n113863.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd51d08863bfed6b0d33800b9268f3b06e9c8245 Binary files /dev/null and b/examples/n113863.jpg differ diff --git a/examples/n11399.jpg b/examples/n11399.jpg new file mode 100644 index 0000000000000000000000000000000000000000..098e1a61e1bb4c45651602f69d7d3f2dd369e777 Binary files /dev/null and b/examples/n11399.jpg differ diff --git a/examples/n115850.jpg b/examples/n115850.jpg new file mode 100644 index 0000000000000000000000000000000000000000..017c5a829ccb429d450cfbc608bbb8a94a8cab8d Binary files /dev/null and b/examples/n115850.jpg differ diff --git a/examples/n116797.jpg b/examples/n116797.jpg new file mode 100644 index 0000000000000000000000000000000000000000..267b99b3c558bbb0ffa20bb6f8633e0b9511bb2d Binary files /dev/null and b/examples/n116797.jpg differ diff --git a/examples/n116868.jpg b/examples/n116868.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc7788dadaca132c6c8ce3e52c469d127a2581bb Binary files /dev/null and b/examples/n116868.jpg differ diff --git a/examples/n132998.jpg b/examples/n132998.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c29aea7173f50a583061dbfd6b2e9a63d96cd5fa Binary files /dev/null and b/examples/n132998.jpg differ diff --git a/examples/n137739.jpg b/examples/n137739.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99d70c5d2373d95c6e91ff0ecfe24e8791c6e395 Binary files /dev/null and b/examples/n137739.jpg differ diff --git a/examples/n140477.jpg b/examples/n140477.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e0a9bc527339c5ae1340ebb90b27ee67dedacfc0 Binary files /dev/null and b/examples/n140477.jpg differ diff --git a/examples/n14897.jpg b/examples/n14897.jpg new file mode 100644 index 0000000000000000000000000000000000000000..848f3b45832288767b77fab48942554ca26ac224 Binary files /dev/null and b/examples/n14897.jpg differ diff --git a/examples/n151233.jpg b/examples/n151233.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8aaf9286e4294f07dc84de1c5e2c8cf8fef44d80 Binary files /dev/null and b/examples/n151233.jpg differ diff --git a/examples/n154501.jpg b/examples/n154501.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab3f8e506f8a817997c2eed6ef85ffa9712decc4 Binary files /dev/null and b/examples/n154501.jpg differ diff --git a/examples/n155638.jpg b/examples/n155638.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9fef01ddea6e25a7debe0e5d15612e81324ace5 Binary files /dev/null and b/examples/n155638.jpg differ diff --git a/examples/n168871.jpg b/examples/n168871.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8359985031c43856fde8190d61500ada7d9c6d3c Binary files /dev/null and b/examples/n168871.jpg differ diff --git a/examples/n173361.jpg b/examples/n173361.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af3bc16e2c518e16041fe7a4df7d8e772f23f00d Binary files /dev/null and b/examples/n173361.jpg differ diff --git a/examples/n173931.jpg b/examples/n173931.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1125d597d8c7f4ce578670fa6f12b7f663ca6156 Binary files /dev/null and b/examples/n173931.jpg differ diff --git a/examples/n176076.jpg b/examples/n176076.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ebdf8ac2290afee1a0681632a1d5b9f67fb49d21 Binary files /dev/null and b/examples/n176076.jpg differ diff --git a/examples/n177259.jpg b/examples/n177259.jpg new file mode 100644 index 0000000000000000000000000000000000000000..213c72178733a565346e4f05a95cb5a24ef283df Binary files /dev/null and b/examples/n177259.jpg differ diff --git a/examples/n177566.jpg b/examples/n177566.jpg new file mode 100644 index 0000000000000000000000000000000000000000..efca865af66a5e540355edb5cc60a920b55bd8a2 Binary files /dev/null and b/examples/n177566.jpg differ diff --git a/examples/n178654.jpg b/examples/n178654.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd2b997c9ba0ca2c321202992a2cf93f442ac677 Binary files /dev/null and b/examples/n178654.jpg differ diff --git a/examples/n179572.jpg b/examples/n179572.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcf60f60f18d46b48e6110144c2feef39aa9e5e3 Binary files /dev/null and b/examples/n179572.jpg differ diff --git a/examples/n183744.jpg b/examples/n183744.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c2687f817994cc28fd954ba5df70bc5b148c0cb Binary files /dev/null and b/examples/n183744.jpg differ diff --git a/examples/n188669.jpg b/examples/n188669.jpg new file mode 100644 index 0000000000000000000000000000000000000000..18ff2c6e1350d9951717abee21ae1d9e4c48f13c Binary files /dev/null and b/examples/n188669.jpg differ diff --git a/examples/n193989.jpg b/examples/n193989.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fad1e8a49aad8bfef9286a03a96cae25e9e12d1e Binary files /dev/null and b/examples/n193989.jpg differ diff --git a/examples/n194711.jpg b/examples/n194711.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03497d89639af771c9feea26b29e05e43cfd7c07 Binary files /dev/null and b/examples/n194711.jpg differ diff --git a/examples/n196522.jpg b/examples/n196522.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e468eda8e82699391f04a19987fa5c200c5f1418 Binary files /dev/null and b/examples/n196522.jpg differ diff --git a/examples/n209769.jpg b/examples/n209769.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a00efa546c81b0462c6f81fb804184e2e6e38708 Binary files /dev/null and b/examples/n209769.jpg differ diff --git a/examples/n210898.jpg b/examples/n210898.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc2b946d25b420f1b9d49c078a308c85da3153b3 Binary files /dev/null and b/examples/n210898.jpg differ diff --git a/examples/n222443.jpg b/examples/n222443.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6ac2b65bc873c5ddde86a3a7acd69393ac7e7040 Binary files /dev/null and b/examples/n222443.jpg differ diff --git a/examples/n2381.jpg b/examples/n2381.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5611ee849011c2d12df02e454bad5d2de0030b1a Binary files /dev/null and b/examples/n2381.jpg differ diff --git a/examples/n238886.jpg b/examples/n238886.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1750e4805e7105cfdfe2c81330eb7e67822b0cfc Binary files /dev/null and b/examples/n238886.jpg differ diff --git a/examples/n241130.jpg b/examples/n241130.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c366d28fff35c38d72bec4717d97fef466d3eb22 Binary files /dev/null and b/examples/n241130.jpg differ diff --git a/examples/n241451.jpg b/examples/n241451.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b185973de555a773a7134216efc4622fdf0fc000 Binary files /dev/null and b/examples/n241451.jpg differ diff --git a/examples/n241713.jpg b/examples/n241713.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f27e998ec18c359607ed94843b200e34f186236a Binary files /dev/null and b/examples/n241713.jpg differ diff --git a/examples/n24680.jpg b/examples/n24680.jpg new file mode 100644 index 0000000000000000000000000000000000000000..46d876a331952b7316d80afd59a20ea3e96e2b63 Binary files /dev/null and b/examples/n24680.jpg differ diff --git a/examples/n249342.jpg b/examples/n249342.jpg new file mode 100644 index 0000000000000000000000000000000000000000..484c3cdbd4989cd45bc028f692b92f3d21f495bf Binary files /dev/null and b/examples/n249342.jpg differ diff --git a/examples/n25398.jpg b/examples/n25398.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a1b7259b5f34eebffbb95a3357adb4518a2985f0 Binary files /dev/null and b/examples/n25398.jpg differ diff --git a/examples/n256710.jpg b/examples/n256710.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e10e9a74c920a365b39f164fd4fcc190bf1c9156 Binary files /dev/null and b/examples/n256710.jpg differ diff --git a/examples/n272929.jpg b/examples/n272929.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5fd172764725de7a9a804fbf3186b5f3e7ef2083 Binary files /dev/null and b/examples/n272929.jpg differ diff --git a/examples/n278426.jpg b/examples/n278426.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f56c802e0e979b0a379b930bce0940e8b08e982f Binary files /dev/null and b/examples/n278426.jpg differ diff --git a/examples/n279408.jpg b/examples/n279408.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4191f913d6928e3461e187c3f6c2a409ab4f99b1 Binary files /dev/null and b/examples/n279408.jpg differ diff --git a/examples/n282460.jpg b/examples/n282460.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da558a9e53d9b93a6e74b22a1028d9f37bf17583 Binary files /dev/null and b/examples/n282460.jpg differ diff --git a/examples/n288083.jpg b/examples/n288083.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e86974b3b3d86867c6c53216c0f331c45fa99552 Binary files /dev/null and b/examples/n288083.jpg differ diff --git a/examples/n291937.jpg b/examples/n291937.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1cf1b1d951dcabba354a1b4de6790a2fa4d585df Binary files /dev/null and b/examples/n291937.jpg differ diff --git a/examples/n305396.jpg b/examples/n305396.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25adb833482ddcae1cba96dbcb510e2101d175e5 Binary files /dev/null and b/examples/n305396.jpg differ diff --git a/examples/n306745.jpg b/examples/n306745.jpg new file mode 100644 index 0000000000000000000000000000000000000000..32c36e06da007f024f2d296dce3a226a345c05e3 Binary files /dev/null and b/examples/n306745.jpg differ diff --git a/examples/n308857.jpg b/examples/n308857.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7eeaf4131de70df983dd07469efc6c9cc5e5e8e2 Binary files /dev/null and b/examples/n308857.jpg differ diff --git a/examples/n310594.jpg b/examples/n310594.jpg new file mode 100644 index 0000000000000000000000000000000000000000..06d626206478945e4185f55ecc67fcad6961843c Binary files /dev/null and b/examples/n310594.jpg differ diff --git a/examples/n312733.jpg b/examples/n312733.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c51131ee5879050a7918bdb56675169bb227f302 Binary files /dev/null and b/examples/n312733.jpg differ diff --git a/examples/n322129.jpg b/examples/n322129.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c600ecb53f9572266a7eea484a3af34dbf65d69d Binary files /dev/null and b/examples/n322129.jpg differ diff --git a/examples/n326280.jpg b/examples/n326280.jpg new file mode 100644 index 0000000000000000000000000000000000000000..54736388851de12d086bb0e05edb43d79fbf5f4e Binary files /dev/null and b/examples/n326280.jpg differ diff --git a/examples/n330963.jpg b/examples/n330963.jpg new file mode 100644 index 0000000000000000000000000000000000000000..83b80838230e9c327feb6287d129472081e735f4 Binary files /dev/null and b/examples/n330963.jpg differ diff --git a/examples/n331938.jpg b/examples/n331938.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ebd18353a61ed02a87b297c90c3a23e0d49ec6c3 Binary files /dev/null and b/examples/n331938.jpg differ diff --git a/examples/n347432.jpg b/examples/n347432.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c465fc93f4574bbfc9d234f8869872b845c9de93 Binary files /dev/null and b/examples/n347432.jpg differ diff --git a/examples/n350699.jpg b/examples/n350699.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2591f9a83a6ecf1ef1ade1dcb7e8f49eb91bbcdc Binary files /dev/null and b/examples/n350699.jpg differ diff --git a/examples/n365465.jpg b/examples/n365465.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c4d0b759fc00571e180c60ae23772e04ea859d98 Binary files /dev/null and b/examples/n365465.jpg differ diff --git a/examples/n374566.jpg b/examples/n374566.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0925f0b4d8f5dad6480b965ed6ea3e67595ab624 Binary files /dev/null and b/examples/n374566.jpg differ diff --git a/examples/n38311.jpg b/examples/n38311.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0e50f906f61525ffee6efe145093a31785bc84a2 Binary files /dev/null and b/examples/n38311.jpg differ diff --git a/examples/n38422.jpg b/examples/n38422.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b84ba475c5c39b19a5a81c7cd76bcaa916b533b9 Binary files /dev/null and b/examples/n38422.jpg differ diff --git a/examples/n392414.jpg b/examples/n392414.jpg new file mode 100644 index 0000000000000000000000000000000000000000..990ef014a37a78e411bacc5b8f155bddeaab8603 Binary files /dev/null and b/examples/n392414.jpg differ diff --git a/examples/n395593.jpg b/examples/n395593.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fdfeb9997c5037fe3c54b8280c02543dd41d0f56 Binary files /dev/null and b/examples/n395593.jpg differ diff --git a/examples/n399930.jpg b/examples/n399930.jpg new file mode 100644 index 0000000000000000000000000000000000000000..46218cefb0363fa029187aea2daa08c7132ca4fc Binary files /dev/null and b/examples/n399930.jpg differ diff --git a/examples/n410473.jpg b/examples/n410473.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7407c756feb6c83f4354e8b95e38fa5085dfc718 Binary files /dev/null and b/examples/n410473.jpg differ diff --git a/examples/n412553.jpg b/examples/n412553.jpg new file mode 100644 index 0000000000000000000000000000000000000000..121d2d563bc39d28ef246c95e6ec0d0404900793 Binary files /dev/null and b/examples/n412553.jpg differ diff --git a/examples/n416031.jpg b/examples/n416031.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0f443955ec96ab069898e20a85e7c516f3cbf54 Binary files /dev/null and b/examples/n416031.jpg differ diff --git a/examples/n4164.jpg b/examples/n4164.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8137d6ee52c400d5227053ac20c1d7b2775eac0c Binary files /dev/null and b/examples/n4164.jpg differ diff --git a/examples/n427806.jpg b/examples/n427806.jpg new file mode 100644 index 0000000000000000000000000000000000000000..531077e52548aad54297c7219ad3f5d6d82dada6 Binary files /dev/null and b/examples/n427806.jpg differ diff --git a/examples/n430186.jpg b/examples/n430186.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5dd928a94d5f2a474beebedf4b896c4589675d96 Binary files /dev/null and b/examples/n430186.jpg differ diff --git a/examples/n435353.jpg b/examples/n435353.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5cccb087a9d3079b6d1f0ffce1b57cc8e61c43ab Binary files /dev/null and b/examples/n435353.jpg differ diff --git a/examples/n440026.jpg b/examples/n440026.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20ab4573a3bbfa5f243f49b4f6e91394f1a04781 Binary files /dev/null and b/examples/n440026.jpg differ diff --git a/examples/n458063.jpg b/examples/n458063.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d9acc5b2156f99b148e82e12f33776a14fcb962d Binary files /dev/null and b/examples/n458063.jpg differ diff --git a/examples/n459656.jpg b/examples/n459656.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77db6969fb0e6b05418b037b98a60c08bdf22a92 Binary files /dev/null and b/examples/n459656.jpg differ diff --git a/examples/n459958.jpg b/examples/n459958.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6988771afdc37b753a60ddae4dd77f48830d35e2 Binary files /dev/null and b/examples/n459958.jpg differ diff --git a/examples/n460670.jpg b/examples/n460670.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ac9d893e283e1f956aab26e4f21f1a6e71b3af1 Binary files /dev/null and b/examples/n460670.jpg differ diff --git a/examples/n470864.jpg b/examples/n470864.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1e3d6fcee507f909366daf81e417f5e4ed2e5a56 Binary files /dev/null and b/examples/n470864.jpg differ diff --git a/examples/n471447.jpg b/examples/n471447.jpg new file mode 100644 index 0000000000000000000000000000000000000000..394932fd4049ccd352adbe002274ac5ed5b9968a Binary files /dev/null and b/examples/n471447.jpg differ diff --git a/examples/n479684.jpg b/examples/n479684.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6575ec2ae7a7dd5855ef8fdde7a0db6d277ad704 Binary files /dev/null and b/examples/n479684.jpg differ diff --git a/examples/n480635.jpg b/examples/n480635.jpg new file mode 100644 index 0000000000000000000000000000000000000000..82aeb07f82525102c53b9a67e840c3c7fba837f4 Binary files /dev/null and b/examples/n480635.jpg differ diff --git a/examples/n481388.jpg b/examples/n481388.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9e477bc842c9314d96b60025b3dfc24b01b9bc2 Binary files /dev/null and b/examples/n481388.jpg differ diff --git a/examples/n483057.jpg b/examples/n483057.jpg new file mode 100644 index 0000000000000000000000000000000000000000..382e366024f4ea2e694fd164439e4333c9e76bce Binary files /dev/null and b/examples/n483057.jpg differ diff --git a/examples/n487782.jpg b/examples/n487782.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8c5b68b7c6c12fcdf3a86f86815cc7a30ae319d6 Binary files /dev/null and b/examples/n487782.jpg differ diff --git a/examples/n488640.jpg b/examples/n488640.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e98408b838af94ae7dc7589e41a9943cfb795254 Binary files /dev/null and b/examples/n488640.jpg differ diff --git a/examples/n488826.jpg b/examples/n488826.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df8a54b86856e94c1c7daf590edf16a5cadb9686 Binary files /dev/null and b/examples/n488826.jpg differ diff --git a/examples/n502688.jpg b/examples/n502688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae4f0d86a7098dd86936ee763fc3d51888094f67 Binary files /dev/null and b/examples/n502688.jpg differ diff --git a/examples/n506031.jpg b/examples/n506031.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bb68eeab8e50b0e54b5d490ed65af507ec7f1d4f Binary files /dev/null and b/examples/n506031.jpg differ diff --git a/examples/n513785.jpg b/examples/n513785.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b06cc3a7a981b6716515c6f92730d4ff2b6a7f0c Binary files /dev/null and b/examples/n513785.jpg differ diff --git a/examples/n51905.jpg b/examples/n51905.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b8e7ad4221367b6674f5f5dbd1127a6eb275ac95 Binary files /dev/null and b/examples/n51905.jpg differ diff --git a/examples/n528163.jpg b/examples/n528163.jpg new file mode 100644 index 0000000000000000000000000000000000000000..40379235e27336d1d3bac50cd8d4b438f74eb62b Binary files /dev/null and b/examples/n528163.jpg differ diff --git a/examples/n546419.jpg b/examples/n546419.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5a98f67352d1db819be4be6a9a828b241324394 Binary files /dev/null and b/examples/n546419.jpg differ diff --git a/examples/n550531.jpg b/examples/n550531.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2605724156c79ea6a3a1bae975298e6fcae97279 Binary files /dev/null and b/examples/n550531.jpg differ diff --git a/examples/n557532.jpg b/examples/n557532.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3f2da21048b4bcac8c96781a773bbc2dfc88a615 Binary files /dev/null and b/examples/n557532.jpg differ diff --git a/examples/n561498.jpg b/examples/n561498.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5352ab10da0f2ba0c6e1b8937977bddfbf7cedb Binary files /dev/null and b/examples/n561498.jpg differ diff --git a/examples/n563684.jpg b/examples/n563684.jpg new file mode 100644 index 0000000000000000000000000000000000000000..beb16b7e1f0329cb737fe6d06274dd9a001f2c90 Binary files /dev/null and b/examples/n563684.jpg differ diff --git a/examples/n56870.jpg b/examples/n56870.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25218d69850c3dfbd83df6742dd433e2d38fd69b Binary files /dev/null and b/examples/n56870.jpg differ diff --git a/examples/n578685.jpg b/examples/n578685.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b0b13f0392a714ad2b5e41028c9272ec560d39d9 Binary files /dev/null and b/examples/n578685.jpg differ diff --git a/examples/n580875.jpg b/examples/n580875.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bf6ff75e60010d06b20a3e6c2b9489ec96755402 Binary files /dev/null and b/examples/n580875.jpg differ diff --git a/examples/n6492.jpg b/examples/n6492.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d92bb26c7a1b871e0c87e45b81d618cdfeca1dbd Binary files /dev/null and b/examples/n6492.jpg differ diff --git a/examples/n68057.jpg b/examples/n68057.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c1c652578f8260f5f7427c016b26d2196c19e158 Binary files /dev/null and b/examples/n68057.jpg differ diff --git a/examples/n68923.jpg b/examples/n68923.jpg new file mode 100644 index 0000000000000000000000000000000000000000..23e88f87521f100f495434a4fb6463d35e342e6d Binary files /dev/null and b/examples/n68923.jpg differ diff --git a/examples/n98951.jpg b/examples/n98951.jpg new file mode 100644 index 0000000000000000000000000000000000000000..91202b6dd54a9174ed56d361e61110a2d9c816c4 Binary files /dev/null and b/examples/n98951.jpg differ diff --git a/examples/questions.json b/examples/questions.json new file mode 100644 index 0000000000000000000000000000000000000000..dec32ce8c061eb14304c004719675bf0d1af9fe8 --- /dev/null +++ b/examples/questions.json @@ -0,0 +1 @@ +[{"isBalanced": true, "question": "Is the ice cream near the other ice cream closed or open?", "imageId": "n111074", "key": "20701458"}, {"isBalanced": true, "question": "What is standing at the table?", "imageId": "n6492", "key": "20870679"}, {"isBalanced": true, "question": "Are all the people female?", "imageId": "n312733", "key": "201478268"}, {"isBalanced": true, "question": "How large is the plate the fork is on?", "imageId": "n196522", "key": "202217814"}, {"isBalanced": true, "question": "Are there lambs or cows that are lying?", "imageId": "n11399", "key": "201525764"}, {"isBalanced": true, "question": "What is the batter standing on?", "imageId": "n238886", "key": "20297136"}, {"isBalanced": true, "question": "What dessert is to the right of the blueberry?", "imageId": "n331938", "key": "202021868"}, {"isBalanced": true, "question": "Is it an indoors scene?", "imageId": "n178654", "key": "201239127"}, {"isBalanced": true, "question": "What is the animal in front of the fence?", "imageId": "n416031", "key": "201295570"}, {"isBalanced": true, "question": "Are the cards small or large?", "imageId": "n155638", "key": "20272426"}, {"isBalanced": true, "question": "Is there any sand that is dry?", "imageId": "n116868", "key": "202085903"}, {"isBalanced": true, "question": "Which kind of furniture is elevated?", "imageId": "n546419", "key": "201687149"}, {"isBalanced": true, "question": "What is in front of the sink that is not little?", "imageId": "n272929", "key": "201737621"}, {"isBalanced": true, "question": "Is the shirt huge or maybe small?", "imageId": "n140477", "key": "20339922"}, {"isBalanced": true, "question": "How clean is the table that is made of wood?", "imageId": "n137739", "key": "202036180"}, {"isBalanced": true, "question": "What kind of animal is in front of the mirror?", "imageId": "n168871", "key": "201327472"}, {"isBalanced": true, "question": "Is the surfboard light and blue?", "imageId": "n173361", "key": "201766919"}, {"isBalanced": true, "question": "Does the mirror on the side of the truck look round and small?", "imageId": "n38311", "key": "20463054"}, {"isBalanced": true, "question": "Is the plate heavy and blue?", "imageId": "n241451", "key": "20369375"}, {"isBalanced": true, "question": "Is the pepper in front of the stairs that are made of metal?", "imageId": "n151233", "key": "201722741"}, {"isBalanced": true, "question": "Is the spoon hard?", "imageId": "n470864", "key": "2049585"}, {"isBalanced": true, "question": "What is the rock resting on?", "imageId": "n506031", "key": "201198326"}, {"isBalanced": true, "question": "What device is watching the person that is in front of the building?", "imageId": "n291937", "key": "20919560"}, {"isBalanced": true, "question": "What is in front of the man?", "imageId": "n459656", "key": "201147138"}, {"isBalanced": true, "question": "What is the item of furniture in front of the white wall?", "imageId": "n179572", "key": "20985605"}, {"isBalanced": true, "question": "Is that shirt long sleeved or short sleeved?", "imageId": "n38422", "key": "201101201"}, {"isBalanced": true, "question": "Who is waiting for the vehicle to the left of the police officer?", "imageId": "n115850", "key": "201296804"}, {"isBalanced": true, "question": "Who is standing on the ground?", "imageId": "n480635", "key": "20245327"}, {"isBalanced": true, "question": "How clean are the shoes?", "imageId": "n410473", "key": "20535372"}, {"isBalanced": true, "question": "Are there pots on top of the large appliance?", "imageId": "n322129", "key": "20824703"}, {"isBalanced": true, "question": "Which kind of furniture is black?", "imageId": "n310594", "key": "201728631"}, {"isBalanced": true, "question": "Does that snow board look smooth?", "imageId": "n177259", "key": "20132647"}, {"isBalanced": true, "question": "Which kind of furniture is made of leather?", "imageId": "n116797", "key": "201938506"}, {"isBalanced": true, "question": "Are there both umpires and baseballs in the photo?", "imageId": "n4164", "key": "201965886"}, {"isBalanced": true, "question": "Does the sheep near the rock look male?", "imageId": "n176076", "key": "201764175"}, {"isBalanced": true, "question": "How big is the motorbike?", "imageId": "n502688", "key": "201712414"}, {"isBalanced": true, "question": "Is the mouse made of the same material as the table?", "imageId": "n561498", "key": "201162711"}, {"isBalanced": true, "question": "What is underneath the shelf made of wood?", "imageId": "n193989", "key": "20364571"}, {"isBalanced": true, "question": "What the long piece of furniture is called?", "imageId": "n580875", "key": "202125126"}, {"isBalanced": true, "question": "Is that man tall and happy?", "imageId": "n365465", "key": "202056616"}, {"isBalanced": true, "question": "What are the pieces of furniture that are wooden?", "imageId": "n487782", "key": "20430925"}, {"isBalanced": true, "question": "Is the chair that is to the right of the woman made of cloth or leather?", "imageId": "n14897", "key": "201911004"}, {"isBalanced": true, "question": "Is the baseball bat yellow?", "imageId": "n395593", "key": "202141201"}, {"isBalanced": true, "question": "Which kind of clothing is simple?", "imageId": "n350699", "key": "201307631"}, {"isBalanced": true, "question": "In front of what type of appliance is she?", "imageId": "n330963", "key": "201226829"}, {"isBalanced": true, "question": "What fruit is in the whipped cream?", "imageId": "n25398", "key": "20806900"}, {"isBalanced": true, "question": "What is common to the name tag and the sticker?", "imageId": "n241713", "key": "202054197"}, {"isBalanced": true, "question": "Which kind of device is on the desk?", "imageId": "n194711", "key": "201072681"}, {"isBalanced": true, "question": "Which kind of animal is in front of the man?", "imageId": "n488640", "key": "201246286"}, {"isBalanced": true, "question": "What do both the seat and the sidewalk have in common?", "imageId": "n2381", "key": "20504549"}, {"isBalanced": true, "question": "What kind of furniture is in front of the desk?", "imageId": "n132998", "key": "20546526"}, {"isBalanced": true, "question": "Who is wearing the dress?", "imageId": "n326280", "key": "201397977"}, {"isBalanced": true, "question": "Where is the gravel?", "imageId": "n374566", "key": "2010307"}, {"isBalanced": true, "question": "How clean is the toilet seat that looks white?", "imageId": "n578685", "key": "20593080"}, {"isBalanced": true, "question": "Is there either a mouse or a keyboard that is not silver?", "imageId": "n458063", "key": "201271121"}, {"isBalanced": true, "question": "What is the coffee table in front of?", "imageId": "n288083", "key": "202188752"}, {"isBalanced": true, "question": "How large is the shirt that is white?", "imageId": "n241130", "key": "201456944"}, {"isBalanced": true, "question": "How big is the fork?", "imageId": "n483057", "key": "201614508"}, {"isBalanced": true, "question": "Is the island small and square?", "imageId": "n305396", "key": "201517286"}, {"isBalanced": true, "question": "Does the television to the right of the movies seem to be on?", "imageId": "n154501", "key": "201580213"}, {"isBalanced": true, "question": "Are both the mirror and the art work made of the same material?", "imageId": "n430186", "key": "201472818"}, {"isBalanced": true, "question": "Who is wearing a helmet?", "imageId": "n481388", "key": "201632901"}, {"isBalanced": true, "question": "What is the item of furniture that is behind the person that wears glasses?", "imageId": "n399930", "key": "20416127"}, {"isBalanced": true, "question": "Which material makes up the tray, aluminum or plastic?", "imageId": "n113863", "key": "201855573"}, {"isBalanced": true, "question": "What's sitting next to the faucet?", "imageId": "n392414", "key": "201139357"}, {"isBalanced": true, "question": "Are the snow shoes black or white?", "imageId": "n24680", "key": "202184758"}, {"isBalanced": true, "question": "The cotton blanket is what color?", "imageId": "n188669", "key": "20124226"}, {"isBalanced": true, "question": "On which side is the short curtain?", "imageId": "n460670", "key": "20363328"}, {"isBalanced": true, "question": "What is the color of the console?", "imageId": "n563684", "key": "202070638"}, {"isBalanced": true, "question": "Is the mirror that is not big both old fashioned and metallic?", "imageId": "n279408", "key": "2021361"}, {"isBalanced": true, "question": "What's the horse in?", "imageId": "n173931", "key": "20661087"}, {"isBalanced": true, "question": "The boat is where?", "imageId": "n459958", "key": "2019723"}, {"isBalanced": true, "question": "What is the man wearing?", "imageId": "n513785", "key": "201524455"}, {"isBalanced": true, "question": "How does the bicycle behind the table look like, modern or antique?", "imageId": "n440026", "key": "201164906"}, {"isBalanced": true, "question": "Is it outdoors?", "imageId": "n209769", "key": "202132257"}, {"isBalanced": true, "question": "What is the device that is to the left of the person that is wearing a backpack?", "imageId": "n256710", "key": "201001088"}, {"isBalanced": true, "question": "What is the long sleeved clothing item?", "imageId": "n471447", "key": "20833702"}, {"isBalanced": true, "question": "What's the bread in front of?", "imageId": "n177566", "key": "201068990"}, {"isBalanced": true, "question": "Which kind of furniture is below the monitor?", "imageId": "n550531", "key": "201934863"}, {"isBalanced": true, "question": "Does the person to the left of the tennis racket wear a hat?", "imageId": "n210898", "key": "201122640"}, {"isBalanced": true, "question": "What animal is standing on top of the surfboard?", "imageId": "n98951", "key": "201083292"}, {"isBalanced": true, "question": "What is she doing?", "imageId": "n278426", "key": "20648831"}, {"isBalanced": true, "question": "Are the red vegetables on top of the square plate?", "imageId": "n222443", "key": "20904479"}, {"isBalanced": true, "question": "How hard is the bench?", "imageId": "n347432", "key": "20565248"}, {"isBalanced": true, "question": "Is the apron in the image worn on the woman that wears glasses?", "imageId": "n412553", "key": "201955363"}, {"isBalanced": true, "question": "How long are the draperies to the right of the computer mouse?", "imageId": "n479684", "key": "201172637"}, {"isBalanced": true, "question": "What is located on top of the dessert that is on top of the plate?", "imageId": "n528163", "key": "201393357"}, {"isBalanced": true, "question": "How are the animals that are on top of the beige object that is in front of the wall called?", "imageId": "n183744", "key": "20103522"}, {"isBalanced": true, "question": "What is the grass in front of?", "imageId": "n249342", "key": "202152143"}, {"isBalanced": true, "question": "What is the device that he is holding?", "imageId": "n56870", "key": "201418388"}, {"isBalanced": true, "question": "What is the table made of?", "imageId": "n68923", "key": "201694851"}, {"isBalanced": true, "question": "What is the dessert that looks small sitting atop?", "imageId": "n282460", "key": "20111491"}, {"isBalanced": true, "question": "What appliance is white?", "imageId": "n427806", "key": "201751299"}, {"isBalanced": true, "question": "What is the boy that is not uncomfortable doing?", "imageId": "n68057", "key": "20870936"}, {"isBalanced": true, "question": "What animals are it?", "imageId": "n557532", "key": "20578852"}, {"isBalanced": true, "question": "Is the camera large or small?", "imageId": "n306745", "key": "201581894"}, {"isBalanced": true, "question": "Which kind of furniture is sturdy?", "imageId": "n51905", "key": "201403523"}, {"isBalanced": true, "question": "What is attached to the animal that is big?", "imageId": "n308857", "key": "20221798"}, {"isBalanced": true, "question": "Which kind of furniture is not sliding, the drawer or the desk?", "imageId": "n488826", "key": "201889747"}, {"isBalanced": true, "question": "Is the white towel hanging from the oven?", "imageId": "n435353", "key": "20563573"}] \ No newline at end of file diff --git a/image_patch.py b/image_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..49c5262cb72524a522dea249b86366f8f9e339bf --- /dev/null +++ b/image_patch.py @@ -0,0 +1,520 @@ +from __future__ import annotations + +import json +import pathlib +import re +from typing import Tuple +from typing import Union, List + +import numpy as np +import torch +from PIL import Image +from dateutil import parser as dateparser +from torchvision import transforms +from torchvision.ops import box_iou +from word2number import w2n + +from vision_processes import forward + + +def load_json(path: str): + if isinstance(path, str): + path = pathlib.Path(path) + if path.suffix != '.json': + path = path.with_suffix('.json') + with open(path, 'r') as f: + data = json.load(f) + return data + + +class ImagePatch: + """A Python class containing a crop of an image centered around a particular object, as well as relevant + information. + Attributes + ---------- + cropped_image : array_like + An array-like of the cropped image taken from the original image. + left : int + An int describing the position of the left border of the crop's bounding box in the original image. + lower : int + An int describing the position of the bottom border of the crop's bounding box in the original image. + right : int + An int describing the position of the right border of the crop's bounding box in the original image. + upper : int + An int describing the position of the top border of the crop's bounding box in the original image. + + Methods + ------- + find(object_name: str)->List[ImagePatch] + Returns a list of new ImagePatch objects containing crops of the image centered around any objects found in the + image matching the object_name. + exists(object_name: str)->bool + Returns True if the object specified by object_name is found in the image, and False otherwise. + verify_property(property: str)->bool + Returns True if the property is met, and False otherwise. + best_text_match(option_list: List[str], prefix: str)->str + Returns the string that best matches the image. + simple_query(question: str=None)->str + Returns the answer to a basic question asked about the image. If no question is provided, returns the answer + to "What is this?". + compute_depth()->float + Returns the median depth of the image crop. + crop(left: int, lower: int, right: int, upper: int)->ImagePatch + Returns a new ImagePatch object containing a crop of the image at the given coordinates. + """ + + def __init__(self, image: Union[Image.Image, torch.Tensor, np.ndarray], left: int = None, lower: int = None, + right: int = None, upper: int = None, parent_left=0, parent_lower=0, queues=None, + parent_img_patch=None): + """Initializes an ImagePatch object by cropping the image at the given coordinates and stores the coordinates as + attributes. If no coordinates are provided, the image is left unmodified, and the coordinates are set to the + dimensions of the image. + + Parameters + ------- + image : array_like + An array-like of the original image. + left : int + An int describing the position of the left border of the crop's bounding box in the original image. + lower : int + An int describing the position of the bottom border of the crop's bounding box in the original image. + right : int + An int describing the position of the right border of the crop's bounding box in the original image. + upper : int + An int describing the position of the top border of the crop's bounding box in the original image. + + """ + + if isinstance(image, Image.Image): + image = transforms.ToTensor()(image) + elif isinstance(image, np.ndarray): + image = torch.tensor(image).permute(1, 2, 0) + elif isinstance(image, torch.Tensor) and image.dtype == torch.uint8: + image = image / 255 + + if left is None and right is None and upper is None and lower is None: + self.cropped_image = image + self.left = 0 + self.lower = 0 + self.right = image.shape[2] # width + self.upper = image.shape[1] # height + else: + self.cropped_image = image[:, image.shape[1] - upper:image.shape[1] - lower, left:right] + self.left = left + parent_left + self.upper = upper + parent_lower + self.right = right + parent_left + self.lower = lower + parent_lower + + self.height = self.cropped_image.shape[1] + self.width = self.cropped_image.shape[2] + + self.cache = {} + self.queues = (None, None) if queues is None else queues + + self.parent_img_patch = parent_img_patch + + self.horizontal_center = (self.left + self.right) / 2 + self.vertical_center = (self.lower + self.upper) / 2 + + if self.cropped_image.shape[1] == 0 or self.cropped_image.shape[2] == 0: + raise Exception("ImagePatch has no area") + + self.possible_options = load_json('./useful_lists/possible_options.json') + + def forward(self, model_name, *args, **kwargs): + return forward(model_name, *args, **kwargs) + # return forward(model_name, *args, queues=self.queues, **kwargs) + + @property + def original_image(self): + if self.parent_img_patch is None: + return self.cropped_image + else: + return self.parent_img_patch.original_image + + def find(self, object_name: str, confidence_threshold: float = None, return_confidence: bool = False) -> List: + """Returns a list of ImagePatch objects matching object_name contained in the crop if any are found. + Otherwise, returns an empty list. + Parameters + ---------- + object_name : str + the name of the object to be found + + Returns + ------- + List[ImagePatch] + a list of ImagePatch objects matching object_name contained in the crop + """ + if confidence_threshold is not None: + confidence_threshold = float(confidence_threshold) + + if object_name in ["object", "objects"]: + all_object_coordinates, all_object_scores = self.forward('maskrcnn', self.cropped_image, + confidence_threshold=confidence_threshold) + all_object_coordinates = all_object_coordinates[0] + all_object_scores = all_object_scores[0] + else: + if object_name == 'person': + object_name = 'people' # GLIP does better at people than person + + all_object_coordinates, all_object_scores = self.forward('glip', self.cropped_image, object_name, + confidence_threshold=confidence_threshold) + if len(all_object_coordinates) == 0: + return [] + + threshold = 0.0 + if threshold > 0: + area_im = self.width * self.height + all_areas = torch.tensor([(coord[2] - coord[0]) * (coord[3] - coord[1]) / area_im + for coord in all_object_coordinates]) + mask = all_areas > threshold + # if not mask.any(): + # mask = all_areas == all_areas.max() # At least return one element + all_object_coordinates = all_object_coordinates[mask] + all_object_scores = all_object_scores[mask] + + boxes = [self.crop(*coordinates) for coordinates in all_object_coordinates] + if return_confidence: + return [(box, float(score)) for box, score in zip(boxes, all_object_scores.reshape(-1))] + else: + return boxes + + def exists(self, object_name) -> bool: + """Returns True if the object specified by object_name is found in the image, and False otherwise. + Parameters + ------- + object_name : str + A string describing the name of the object to be found in the image. + """ + if object_name.isdigit() or object_name.lower().startswith("number"): + object_name = object_name.lower().replace("number", "").strip() + + object_name = w2n.word_to_num(object_name) + answer = self.simple_query("What number is written in the image (in digits)?") + return w2n.word_to_num(answer) == object_name + + patches = self.find(object_name) + + filtered_patches = [] + for patch in patches: + if "yes" in patch.simple_query(f"Is this a {object_name}?"): + filtered_patches.append(patch) + return len(filtered_patches) > 0 + + def _score(self, category: str, negative_categories=None, model='clip') -> float: + """ + Returns a binary score for the similarity between the image and the category. + The negative categories are used to compare to (score is relative to the scores of the negative categories). + """ + if model == 'clip': + res = self.forward('clip', self.cropped_image, category, task='score', + negative_categories=negative_categories) + elif model == 'tcl': + res = self.forward('tcl', self.cropped_image, category, task='score') + else: # xvlm + task = 'binary_score' if negative_categories is not None else 'score' + res = self.forward('xvlm', self.cropped_image, category, task=task, negative_categories=negative_categories) + res = res.item() + + return res + + def _detect(self, category: str, thresh, negative_categories=None, model='clip') -> Tuple[bool, float]: + score = self._score(category, negative_categories, model) + return score > thresh, float(score) + + def verify_property(self, object_name: str, attribute: str, return_confidence: bool = False): + """Returns True if the object possesses the property, and False otherwise. + Differs from 'exists' in that it presupposes the existence of the object specified by object_name, instead + checking whether the object possesses the property. + Parameters + ------- + object_name : str + A string describing the name of the object to be found in the image. + attribute : str + A string describing the property to be checked. + """ + name = f"{attribute} {object_name}" + model = "xvlm" + negative_categories = [f"{att} {object_name}" for att in self.possible_options['attributes']] + # if model == 'clip': + # ret, score = self._detect(name, negative_categories=negative_categories, + # thresh=config.verify_property.thresh_clip, model='clip') + # elif model == 'tcl': + # ret, score = self._detect(name, thresh=config.verify_property.thresh_tcl, model='tcl') + # else: # 'xvlm' + ret, score = self._detect(name, negative_categories=negative_categories, thresh=0.6, model='xvlm') + + if return_confidence: + return ret, score + else: + return ret + + def best_text_match(self, option_list: list[str] = None, prefix: str = None) -> str: + """Returns the string that best matches the image. + Parameters + ------- + option_list : str + A list with the names of the different options + prefix : str + A string with the prefixes to append to the options + """ + option_list_to_use = option_list + if prefix is not None: + option_list_to_use = [prefix + " " + option for option in option_list] + + model_name = "xvlm" + image = self.cropped_image + text = option_list_to_use + if model_name in ('clip', 'tcl'): + selected = self.forward(model_name, image, text, task='classify') + elif model_name == 'xvlm': + res = self.forward(model_name, image, text, task='score') + res = res.argmax().item() + selected = res + else: + raise NotImplementedError + + return option_list[selected] + + def simple_query(self, question: str, return_confidence: bool = False): + """Returns the answer to a basic question asked about the image. If no question is provided, returns the answer + to "What is this?". The questions are about basic perception, and are not meant to be used for complex reasoning + or external knowledge. + Parameters + ------- + question : str + A string describing the question to be asked. + """ + text, score = self.forward('blip', self.cropped_image, question, task='qa') + if return_confidence: + return text, score + else: + return text + + def compute_depth(self): + """Returns the median depth of the image crop + Parameters + ---------- + Returns + ------- + float + the median depth of the image crop + """ + original_image = self.original_image + depth_map = self.forward('depth', original_image) + depth_map = depth_map[original_image.shape[1] - self.upper:original_image.shape[1] - self.lower, + self.left:self.right] + return depth_map.median() # Ideally some kind of mode, but median is good enough for now + + def crop(self, left: int, lower: int, right: int, upper: int) -> ImagePatch: + """Returns a new ImagePatch containing a crop of the original image at the given coordinates. + Parameters + ---------- + left : int + the position of the left border of the crop's bounding box in the original image + lower : int + the position of the bottom border of the crop's bounding box in the original image + right : int + the position of the right border of the crop's bounding box in the original image + upper : int + the position of the top border of the crop's bounding box in the original image + + Returns + ------- + ImagePatch + a new ImagePatch containing a crop of the original image at the given coordinates + """ + # make all inputs ints + left = int(left) + lower = int(lower) + right = int(right) + upper = int(upper) + + if True: + left = max(0, left - 10) + lower = max(0, lower - 10) + right = min(self.width, right + 10) + upper = min(self.height, upper + 10) + + return ImagePatch(self.cropped_image, left, lower, right, upper, self.left, self.lower, queues=self.queues, + parent_img_patch=self) + + def overlaps_with(self, left, lower, right, upper): + """Returns True if a crop with the given coordinates overlaps with this one, + else False. + Parameters + ---------- + left : int + the left border of the crop to be checked + lower : int + the lower border of the crop to be checked + right : int + the right border of the crop to be checked + upper : int + the upper border of the crop to be checked + + Returns + ------- + bool + True if a crop with the given coordinates overlaps with this one, else False + """ + return self.left <= right and self.right >= left and self.lower <= upper and self.upper >= lower + + def llm_query(self, question: str, long_answer: bool = True) -> str: + return llm_query(question, None, long_answer) + + # def print_image(self, size: tuple[int, int] = None): + # show_single_image(self.cropped_image, size) + + def __repr__(self): + return "ImagePatch(left={}, right={}, upper={}, lower={}, height={}, width={}, horizontal_center={}, vertical_center={})".format( + self.left, self.right, self.upper, self.lower, self.height, self.width, + self.horizontal_center, self.vertical_center, + ) + # return "ImagePatch({}, {}, {}, {})".format(self.left, self.lower, self.right, self.upper) + + +def best_image_match(list_patches: list[ImagePatch], content: List[str], return_index: bool = False) -> \ + Union[ImagePatch, None]: + """Returns the patch most likely to contain the content. + Parameters + ---------- + list_patches : List[ImagePatch] + content : List[str] + the object of interest + return_index : bool + if True, returns the index of the patch most likely to contain the object + + Returns + ------- + int + Patch most likely to contain the object + """ + if len(list_patches) == 0: + return None + + model = "xvlm" + + scores = [] + for cont in content: + if model == 'clip': + res = list_patches[0].forward(model, [p.cropped_image for p in list_patches], cont, task='compare', + return_scores=True) + else: + res = list_patches[0].forward(model, [p.cropped_image for p in list_patches], cont, task='score') + scores.append(res) + scores = torch.stack(scores).mean(dim=0) + scores = scores.argmax().item() # Argmax over all image patches + + if return_index: + return scores + return list_patches[scores] + + +def distance(patch_a: Union[ImagePatch, float], patch_b: Union[ImagePatch, float]) -> float: + """ + Returns the distance between the edges of two ImagePatches, or between two floats. + If the patches overlap, it returns a negative distance corresponding to the negative intersection over union. + """ + + if isinstance(patch_a, ImagePatch) and isinstance(patch_b, ImagePatch): + a_min = np.array([patch_a.left, patch_a.lower]) + a_max = np.array([patch_a.right, patch_a.upper]) + b_min = np.array([patch_b.left, patch_b.lower]) + b_max = np.array([patch_b.right, patch_b.upper]) + + u = np.maximum(0, a_min - b_max) + v = np.maximum(0, b_min - a_max) + + dist = np.sqrt((u ** 2).sum() + (v ** 2).sum()) + + if dist == 0: + box_a = torch.tensor([patch_a.left, patch_a.lower, patch_a.right, patch_a.upper])[None] + box_b = torch.tensor([patch_b.left, patch_b.lower, patch_b.right, patch_b.upper])[None] + dist = - box_iou(box_a, box_b).item() + + else: + dist = abs(patch_a - patch_b) + + return dist + + +def bool_to_yesno(bool_answer: bool) -> str: + """Returns a yes/no answer to a question based on the boolean value of bool_answer. + Parameters + ---------- + bool_answer : bool + a boolean value + + Returns + ------- + str + a yes/no answer to a question based on the boolean value of bool_answer + """ + return "yes" if bool_answer else "no" + + +def llm_query(query, context=None, long_answer=True, queues=None): + """Answers a text question using GPT-3. The input question is always a formatted string with a variable in it. + + Parameters + ---------- + query: str + the text question to ask. Must not contain any reference to 'the image' or 'the photo', etc. + """ + if long_answer: + return forward(model_name='gpt3_general', prompt=query, queues=queues) + else: + return forward(model_name='gpt3_qa', prompt=[query, context], queues=queues) + + +def process_guesses(prompt, guess1=None, guess2=None, queues=None): + return forward(model_name='gpt3_guess', prompt=[prompt, guess1, guess2], queues=queues) + + +def coerce_to_numeric(string, no_string=False): + """ + This function takes a string as input and returns a numeric value after removing any non-numeric characters. + If the input string contains a range (e.g. "10-15"), it returns the first value in the range. + # TODO: Cases like '25to26' return 2526, which is not correct. + """ + if any(month in string.lower() for month in ['january', 'february', 'march', 'april', 'may', 'june', 'july', + 'august', 'september', 'october', 'november', 'december']): + try: + return dateparser.parse(string).timestamp().year + except: # Parse Error + pass + + try: + # If it is a word number (e.g. 'zero') + numeric = w2n.word_to_num(string) + return numeric + except ValueError: + pass + + # Remove any non-numeric characters except the decimal point and the negative sign + string_re = re.sub("[^0-9\.\-]", "", string) + + if string_re.startswith('-'): + string_re = '&' + string_re[1:] + + # Check if the string includes a range + if "-" in string_re: + # Split the string into parts based on the dash character + parts = string_re.split("-") + return coerce_to_numeric(parts[0].replace('&', '-')) + else: + string_re = string_re.replace('&', '-') + + try: + # Convert the string to a float or int depending on whether it has a decimal point + if "." in string_re: + numeric = float(string_re) + else: + numeric = int(string_re) + except: + if no_string: + raise ValueError + # No numeric values. Return input + return string + return numeric diff --git a/joint.prompt b/joint.prompt new file mode 100644 index 0000000000000000000000000000000000000000..366671100481549d8e90ff519141a80f3094063b --- /dev/null +++ b/joint.prompt @@ -0,0 +1,768 @@ +from typing import List, Union + +from vision_functions import find_in_image, simple_qa, verify_property, best_text_match, compute_depth + + +def bool_to_yesno(bool_answer: bool) -> str: + return "yes" if bool_answer else "no" + + +class ImagePatch: + """A Python class containing a crop of an image centered around a particular object, as well as relevant information. + Attributes + ---------- + cropped_image : array_like + An array-like of the cropped image taken from the original image. + left : int + An int describing the position of the left border of the crop's bounding box in the original image. + lower : int + An int describing the position of the bottom border of the crop's bounding box in the original image. + right : int + An int describing the position of the right border of the crop's bounding box in the original image. + upper : int + An int describing the position of the top border of the crop's bounding box in the original image. + + Methods + ------- + find(object_name: str) -> List[ImagePatch] + Returns a list of new ImagePatch objects containing crops of the image centered around any objects found in the image matching the object_name. + simple_query(question: str=None) -> str + Returns the answer to a basic question asked about the image. If no question is provided, returns the answer to "What is this?". + exists(object_name: str) -> bool + Returns True if the object specified by object_name is found in the image, and False otherwise. + verify_property(property: str) -> bool + Returns True if the property is met, and False otherwise. + compute_depth()->float + Returns the median depth of the image crop. + best_text_match(string1: str, string2: str) -> str + Returns the string that best matches the image. + crop(left: int, lower: int, right: int, upper: int) -> ImagePatch + Returns a new ImagePatch object containing a crop of the image at the given coordinates. + """ + + def __init__(self, image, left: int = None, lower: int = None, right: int = None, upper: int = None): + """Initializes an ImagePatch object by cropping the image at the given coordinates and stores the coordinates as attributes. + If no coordinates are provided, the image is left unmodified, and the coordinates are set to the dimensions of the image. + Parameters + ------- + image : array_like + An array-like of the original image. + left : int + An int describing the position of the left border of the crop's bounding box in the original image. + lower : int + An int describing the position of the bottom border of the crop's bounding box in the original image. + right : int + An int describing the position of the right border of the crop's bounding box in the original image. + upper : int + An int describing the position of the top border of the crop's bounding box in the original image. + """ + if left is None and right is None and upper is None and lower is None: + self.cropped_image = image + self.left = 0 + self.lower = 0 + self.right = image.shape[2] # width + self.upper = image.shape[1] # height + else: + self.cropped_image = image[:, lower:upper, left:right] + self.left = left + self.upper = upper + self.right = right + self.lower = lower + + self.width = self.cropped_image.shape[2] + self.height = self.cropped_image.shape[1] + + self.horizontal_center = (self.left + self.right) / 2 + self.vertical_center = (self.lower + self.upper) / 2 + + def find(self, object_name: str) -> List["ImagePatch"]: + """Returns a new ImagePatch object containing the crop of the image centered around the object specified by object_name. + Parameters + ------- + object_name : str + A string describing the name of the object to be found in the image. + + Examples + -------- + >>> # Given an image: Find the foo. + >>> def execute_command(image) -> List[ImagePatch]: + >>> image_patch = ImagePatch(image) + >>> foo_patches = image_patch.find("foo") + >>> return foo_patches + """ + return find_in_image(self.cropped_image, object_name) + + def simple_query(self, question: str = None) -> str: + """Returns the answer to a basic question asked about the image. If no question is provided, returns the answer to "What is this?". + Parameters + ------- + question : str + A string describing the question to be asked. + + Examples + ------- + >>> # Given an image: Which kind of animal is not eating? + >>> def execute_command(image) -> str: + >>> image_patch = ImagePatch(image) + >>> animal_patches = image_patch.find("animal") + >>> for animal_patch in animal_patches: + >>> if not animal_patch.verify_property("animal", "eating"): + >>> return animal_patch.simple_query("What kind of animal is eating?") # crop would include eating so keep it in the query + >>> # If no animal is not eating, query the image directly + >>> return image_patch.simple_query("Which kind of animal is not eating?") + + >>> # Given an image: What is in front of the horse? + >>> def execute_command(image) -> str: + >>> image_patch = ImagePatch(image) + >>> # contains a relation (around, next to, on, near, on top of, in front of, behind, etc), so ask directly + >>> return image_patch.simple_query("What is in front of the horse?") + """ + return simple_qa(self.cropped_image, question) + + def exists(self, object_name: str) -> bool: + """Returns True if the object specified by object_name is found in the image, and False otherwise. + Parameters + ------- + object_name : str + A string describing the name of the object to be found in the image. + + Examples + ------- + >>> # Given an image: Are there both cakes and gummy bears in the photo? + >>> def execute_command(image) -> str: + >>> image_patch = ImagePatch(image) + >>> is_cake = image_patch.exists("cake") + >>> is_gummy_bear = image_patch.exists("gummy bear") + >>> return bool_to_yesno(is_cake and is_gummy_bear) + """ + return len(self.find(object_name)) > 0 + + def verify_property(self, object_name: str, property: str) -> bool: + """Returns True if the object possesses the property, and False otherwise. + Differs from 'exists' in that it presupposes the existence of the object specified by object_name, instead checking whether the object possesses the property. + Parameters + ------- + object_name : str + A string describing the name of the object to be found in the image. + property : str + A string describing the property to be checked. + + Examples + ------- + >>> # Given an image: Do the letters have blue color? + >>> def execute_command(image) -> str: + >>> image_patch = ImagePatch(image) + >>> letters_patches = image_patch.find("letters") + >>> # Question assumes only one letter patch + >>> if len(letters_patches) == 0: + >>> # If no letters are found, query the image directly + >>> return image_patch.simple_query("Do the letters have blue color?") + >>> return bool_to_yesno(letters_patches[0].verify_property("letters", "blue")) + """ + return verify_property(self.cropped_image, object_name, property) + + def compute_depth(self): + """Returns the median depth of the image crop + Parameters + ---------- + Returns + ------- + float + the median depth of the image crop + + Examples + -------- + >>> # Given an image: Find the bar furthest away. + >>> def execute_command(image)->ImagePatch: + >>> image_patch = ImagePatch(image) + >>> bar_patches = image_patch.find("bar") + >>> bar_patches.sort(key=lambda bar: bar.compute_depth()) + >>> return bar_patches[-1] + """ + depth_map = compute_depth(self.cropped_image) + return depth_map.median() + + def best_text_match(self, option_list: List[str]) -> str: + """Returns the string that best matches the image. + Parameters + ------- + option_list : str + A list with the names of the different options + prefix : str + A string with the prefixes to append to the options + + Examples + ------- + >>> # Given an image: Is the cap gold or white? + >>> def execute_command(image) -> str: + >>> image_patch = ImagePatch(image) + >>> cap_patches = image_patch.find("cap") + >>> # Question assumes one cap patch + >>> if len(cap_patches) == 0: + >>> # If no cap is found, query the image directly + >>> return image_patch.simple_query("Is the cap gold or white?") + >>> return cap_patches[0].best_text_match(["gold", "white"]) + """ + return best_text_match(self.cropped_image, option_list) + + def crop(self, left: int, lower: int, right: int, upper: int) -> "ImagePatch": + """Returns a new ImagePatch cropped from the current ImagePatch. + Parameters + ------- + left : int + The leftmost pixel of the cropped image. + lower : int + The lowest pixel of the cropped image. + right : int + The rightmost pixel of the cropped image. + upper : int + The uppermost pixel of the cropped image. + ------- + """ + return ImagePatch(self.cropped_image, left, lower, right, upper) + + +def best_image_match(list_patches: List[ImagePatch], content: List[str], return_index=False) -> Union[ImagePatch, int]: + """Returns the patch most likely to contain the content. + Parameters + ---------- + list_patches : List[ImagePatch] + content : List[str] + the object of interest + return_index : bool + if True, returns the index of the patch most likely to contain the object + + Returns + ------- + int + Patch most likely to contain the object + """ + return best_image_match(list_patches, content, return_index) + + +def distance(patch_a: ImagePatch, patch_b: ImagePatch) -> float: + """ + Returns the distance between the edges of two ImagePatches. If the patches overlap, it returns a negative distance + corresponding to the negative intersection over union. + + Parameters + ---------- + patch_a : ImagePatch + patch_b : ImagePatch + + Examples + -------- + # Return the qux that is closest to the foo + >>> def execute_command(image): + >>> image_patch = ImagePatch(image) + >>> qux_patches = image_patch.find('qux') + >>> foo_patches = image_patch.find('foo') + >>> foo_patch = foo_patches[0] + >>> qux_patches.sort(key=lambda x: distance(x, foo_patch)) + >>> return qux_patches[0] + """ + return distance(patch_a, patch_b) + + +# Examples of using ImagePatch + + +# Given an image: What toy is wearing a shirt? +def execute_command(image) -> str: + # not a relational verb so go step by step + image_patch = ImagePatch(image) + toy_patches = image_patch.find("toy") + # Question assumes only one toy patch + if len(toy_patches) == 0: + # If no toy is found, query the image directly + return image_patch.simple_query("What toy is wearing a shirt?") + for toy_patch in toy_patches: + is_wearing_shirt = (toy_patch.simple_query("Is the toy wearing a shirt?") == "yes") + if is_wearing_shirt: + return toy_patch.simple_query( + "What toy is wearing a shirt?") # crop would include the shirt so keep it in the query + # If no toy is wearing a shirt, pick the first toy + return toy_patches[0].simple_query("What toy is wearing a shirt?") + + +# Given an image: Who is the man staring at? +def execute_command(image) -> str: + # asks for the predicate of a relational verb (staring at), so ask directly + image_patch = ImagePatch(image) + return image_patch.simple_query("Who is the man staring at?") + + +# Given an image: Find more visible chair. +def execute_command(image) -> ImagePatch: + # Return the chair + image_patch = ImagePatch(image) + # Remember: return the chair + return image_patch.find("chair")[0] + + +# Given an image: Find lamp on the bottom. +def execute_command(image) -> ImagePatch: + # Return the lamp + image_patch = ImagePatch(image) + lamp_patches = image_patch.find("lamp") + lamp_patches.sort(key=lambda lamp: lamp.vertical_center) + # Remember: return the lamp + return lamp_patches[0] # Return the bottommost lamp + + +# Given a list of images: Does the pole that is near a building that is near a green sign and the pole that is near bushes that are near a green sign have the same material? +def execute_command(image_list) -> str: + material_1 = None + material_2 = None + for image in image_list: + image = ImagePatch(image) + # find the building + building_patches = image.find("building") + for building_patch in building_patches: + poles = building_patch.find("pole") + signs = building_patch.find("sign") + greensigns = [sign for sign in signs if sign.verify_property('sign', 'green')] + if len(poles) > 0 and len(greensigns) > 0: + material_1 = poles[0].simple_query("What is the material of the pole?") + # find the bush + bushes_patches = image.find("bushes") + for bushes_patch in bushes_patches: + poles = bushes_patch.find("pole") + signs = bushes_patch.find("sign") + greensigns = [sign for sign in signs if sign.verify_property('sign', 'green')] + if len(poles) > 0 and len(greensigns) > 0: + material_2 = poles[0].simple_query("What is the material of the pole?") + return bool_to_yesno(material_1 == material_2) + + +# Given an image: Find middle kid. +def execute_command(image) -> ImagePatch: + # Return the kid + image_patch = ImagePatch(image) + kid_patches = image_patch.find("kid") + if len(kid_patches) == 0: + kid_patches = [image_patch] + kid_patches.sort(key=lambda kid: kid.horizontal_center) + # Remember: return the kid + return kid_patches[len(kid_patches) // 2] # Return the middle kid + + +# Given an image: Is that blanket to the right of a pillow? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + blanket_patches = image_patch.find("blanket") + # Question assumes only one blanket patch + if len(blanket_patches) == 0: + # If no blanket is found, query the image directly + return image_patch.simple_query("Is that blanket to the right of a pillow?") + for blanket_patch in blanket_patches: + pillow_patches = image_patch.find("pillow") + for pillow_patch in pillow_patches: + if pillow_patch.horizontal_center > blanket_patch.horizontal_center: + return "yes" + return "no" + + +# Given an image: How many people are there? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + person_patches = image_patch.find("person") + return str(len(person_patches)) + + +# Given a list of images: Is the man that is wearing dark pants driving?. +def execute_command(image_list) -> str: + for image in image_list: + image = ImagePatch(image) + man_patches = image.find("man") + for man_patch in man_patches: + pants = man_patch.find("pants") + if len(pants) == 0: + continue + if pants[0].verify_property("pants", "dark"): + return man_patch.simple_query("Is this man driving?") + return ImagePatch(image_list[0]).simple_query("Is the man that is wearing dark pants driving?") + + +# Given an image: Is there a backpack to the right of the man? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + man_patches = image_patch.find("man") + # Question assumes one man patch + if len(man_patches) == 0: + # If no man is found, query the image directly + return image_patch.simple_query("Is there a backpack to the right of the man?") + man_patch = man_patches[0] + backpack_patches = image_patch.find("backpack") + # Question assumes one backpack patch + if len(backpack_patches) == 0: + return "no" + for backpack_patch in backpack_patches: + if backpack_patch.horizontal_center > man_patch.horizontal_center: + return "yes" + return "no" + + +# Given a list of images: What is the pizza with red tomato on it on? +def execute_command(image_list) -> str: + for image in image_list: + image = ImagePatch(image) + pizza_patches = image.find("pizza") + for pizza_patch in pizza_patches: + tomato_patches = pizza_patch.find("tomato") + has_red_tomato = False + for tomato_patch in tomato_patches: + if tomato_patch.verify_property("tomato", "red"): + has_red_tomato = True + if has_red_tomato: + return pizza_patch.simple_query("What is the pizza on?") + return ImagePatch(image_list[0]).simple_query("What is the pizza with red tomato on it on?") + + +# Given an image: Find chair to the right near the couch. +def execute_command(image) -> ImagePatch: + # Return the chair + image_patch = ImagePatch(image) + chair_patches = image_patch.find("chair") + if len(chair_patches) == 0: + chair_patches = [image_patch] + elif len(chair_patches) == 1: + return chair_patches[0] + chair_patches_right = [c for c in chair_patches if c.horizontal_center > image_patch.horizontal_center] + couch_patches = image_patch.find("couch") + if len(couch_patches) == 0: + couch_patches = [image_patch] + couch_patch = couch_patches[0] + chair_patches_right.sort(key=lambda c: distance(c, couch_patch)) + chair_patch = chair_patches_right[0] + # Remember: return the chair + return chair_patch + + +# Given an image: Are there bagels or lemons? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + is_bagel = image_patch.exists("bagel") + is_lemon = image_patch.exists("lemon") + return bool_to_yesno(is_bagel or is_lemon) + + +# Given an image: In which part is the bread, the bottom or the top? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + bread_patches = image_patch.find("bread") + # Question assumes only one bread patch + if len(bread_patches) == 0: + # If no bread is found, query the image directly + return image_patch.simple_query("In which part is the bread, the bottom or the top?") + if bread_patches[0].vertical_center < image_patch.vertical_center: + return "bottom" + else: + return "top" + + +# Given an image: Find foo to bottom left. +def execute_command(image) -> ImagePatch: + # Return the foo + image_patch = ImagePatch(image) + foo_patches = image_patch.find("foo") + lowermost_coordinate = min([patch.vertical_center for patch in foo_patches]) + foo_patches_bottom = [patch for patch in foo_patches if patch.vertical_center - lowermost_coordinate < 100] + if len(foo_patches_bottom) == 0: + foo_patches_bottom = foo_patches + elif len(foo_patches_bottom) == 1: + return foo_patches_bottom[0] + foo_patches_bottom.sort(key=lambda foo: foo.horizontal_center) + foo_patch = foo_patches_bottom[0] + # Remember: return the foo + return foo_patch + + +# Given an image: Find number 17. +def execute_command(image) -> ImagePatch: + # Return the person + image_patch = ImagePatch(image) + person_patches = image_patch.find("person") + for patch in person_patches: + if patch.exists("17"): + return patch + # Remember: return the person + return person_patches[0] + + +# Given a list of images: Is the statement true? There is at least 1 image with a brown dog that is near a bicycle and is wearing a collar. +def execute_command(image_list) -> str: + for image in image_list: + image = ImagePatch(image) + dog_patches = image.find("dog") + for dog in dog_patches: + near_bicycle = dog.simple_query("Is the dog near a bicycle?") + wearing_collar = dog.simple_query("Is the dog wearing a collar?") + if near_bicycle == "yes" and wearing_collar == "yes": + return 'yes' + return 'no' + + +# Given an image: Find dog to the left of the post who is closest to girl wearing a shirt with text that says "I love you". +def execute_command(image) -> ImagePatch: + # Return the dog + image_patch = ImagePatch(image) + shirt_patches = image_patch.find("shirt") + if len(shirt_patches) == 0: + shirt_patches = [image_patch] + shirt_patch = best_image_match(list_patches=shirt_patches, content=["I love you shirt"]) + post_patches = image_patch.find("post") + post_patches.sort(key=lambda post: distance(post, shirt_patch)) + post_patch = post_patches[0] + dog_patches = image_patch.find("dog") + dogs_left_patch = [dog for dog in dog_patches if dog.left < post_patch.left] + if len(dogs_left_patch) == 0: + dogs_left_patch = dog_patches + dogs_left_patch.sort(key=lambda dog: distance(dog, post_patch)) + dog_patch = dogs_left_patch[0] + # Remember: return the dog + return dog_patch + + +# Given an image: Find balloon on the right and second from the bottom. +def execute_command(image) -> ImagePatch: + # Return the balloon + image_patch = ImagePatch(image) + balloon_patches = image_patch.find("balloon") + if len(balloon_patches) == 0: + balloon_patches = [image_patch] + elif len(balloon_patches) == 1: + return balloon_patches[0] + leftmost_coordinate = min([patch.horizontal_center for patch in balloon_patches]) + balloon_patches_right = [patch for patch in balloon_patches if patch.horizontal_center - leftmost_coordinate < 100] + if len(balloon_patches_right) == 0: + balloon_patches_right = balloon_patches + balloon_patches_right.sort(key=lambda p: p.vertical_center) + balloon_patch = balloon_patches_right[1] + # Remember: return the balloon + return balloon_patch + + +# Given an image: Find girl in white next to man in left. +def execute_command(image) -> ImagePatch: + # Return the girl + image_patch = ImagePatch(image) + girl_patches = image_patch.find("girl") + girl_in_white_patches = [g for g in girl_patches if g.verify_property("girl", "white clothing")] + if len(girl_in_white_patches) == 0: + girl_in_white_patches = girl_patches + man_patches = image_patch.find("man") + man_patches.sort(key=lambda man: man.horizontal_center) + leftmost_man = man_patches[0] # First from the left + girl_in_white_patches.sort(key=lambda girl: distance(girl, leftmost_man)) + girl_patch = girl_in_white_patches[0] + # Remember: return the girl + return girl_patch + + +# Given a list of images: Is the statement true? There is 1 table that is in front of woman that is wearing jacket. +def execute_command(image_list) -> str: + for image in image_list: + image = ImagePatch(image) + woman_patches = image.find("woman") + for woman in woman_patches: + if woman.simple_query("Is the woman wearing jacket?") == "yes": + tables = woman.find("table") + return bool_to_yesno(len(tables) == 1) + return 'no' + + +# Given an image: Find top left. +def execute_command(image) -> ImagePatch: + # Return the person + image_patch = ImagePatch(image) + # Figure out what thing the caption is referring to. We need a subject for every caption + persons = image_patch.find("person") + top_all_objects = max([obj.vertical_center for obj in persons]) + # Select objects that are close to the top + # We do this because the caption is asking first about vertical and then about horizontal + persons_top = [p for p in persons if top_all_objects - p.vertical_center < 100] + if len(persons_top) == 0: + persons_top = persons + # And after that, obtain the leftmost object among them + persons_top.sort(key=lambda obj: obj.horizontal_center) + person_leftmost = persons_top[0] + # Remember: return the person + return person_leftmost + + +# Given an image: What type of weather do you see in the photograph? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + return image_patch.simple_query("What type of weather do you see in the photograph?") + + +# Given an image: How many orange life vests can be seen? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + life_vest_patches = image_patch.find("life vest") + orange_life_vest_patches = [] + for life_vest_patch in life_vest_patches: + if life_vest_patch.verify_property('life vest', 'orange'): + orange_life_vest_patches.append(life_vest_patch) + return str(len(orange_life_vest_patches)) + + +# Given an image: What is behind the pole? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + # contains a relation (around, next to, on, near, on top of, in front of, behind, etc), so ask directly + return image_patch.simple_query("What is behind the pole?") + + +# Given an image: Find second to top flower. +def execute_command(image) -> ImagePatch: + # Return the flower + image_patch = ImagePatch(image) + flower_patches = image_patch.find("flower") + flower_patches.sort(key=lambda flower: flower.vertical_center) + flower_patch = flower_patches[-2] + # Remember: return the flower + return flower_patch + + +# Given an image: Find back. +def execute_command(image) -> ImagePatch: + # Return the person + image_patch = ImagePatch(image) + person_patches = image_patch.find("person") + person_patches.sort(key=lambda person: person.compute_depth()) + person_patch = person_patches[-1] + # Remember: return the person + return person_patch + + +# Given an image: Find chair at the front. +def execute_command(image) -> ImagePatch: + # Return the chair + image_patch = ImagePatch(image) + chair_patches = image_patch.find("chair") + chair_patches.sort(key=lambda chair: chair.compute_depth()) + chair_patch = chair_patches[0] + # Remember: return the chair + return chair_patch + + +# Given an image: Find white and yellow pants. +def execute_command(image) -> ImagePatch: + # Return the person + image_patch = ImagePatch(image) + # Clothing always requires returning the person + person_patches = image_patch.find("person") + person_patch = best_image_match(person_patches, ["white pants", "yellow pants"]) + # Remember: return the person + return person_patch + + +# Given an image: Find cow facing the camera. +def execute_command(image) -> ImagePatch: + # Return the cow + image_patch = ImagePatch(image) + cow_patches = image_patch.find("cow") + if len(cow_patches) == 0: + cow_patches = [image_patch] + cow_patch = best_image_match(list_patches=cow_patches, content=["cow facing the camera"]) + # Remember: return the cow + return cow_patch + + +# Given a list of images: Is the statement true? There is 1 image that contains exactly 3 blue papers. +def execute_command(image_list) -> str: + image_cnt = 0 + for image in image_list: + image = ImagePatch(image) + paper_patches = image.find("paper") + blue_paper_patches = [] + for paper in paper_patches: + if paper.verify_property("paper", "blue"): + blue_paper_patches.append(paper) + if len(blue_paper_patches) == 3: + image_cnt += 1 + return bool_to_yesno(image_cnt == 1) + + +# Given an image: Find black car just under stop sign. +def execute_command(image) -> ImagePatch: + # Return the car + image_patch = ImagePatch(image) + stop_sign_patches = image_patch.find("stop sign") + if len(stop_sign_patches) == 0: + stop_sign_patches = [image_patch] + stop_sign_patch = stop_sign_patches[0] + car_patches = image_patch.find("black car") + car_under_stop = [] + for car in car_patches: + if car.upper < stop_sign_patch.upper: + car_under_stop.append(car) + # Find car that is closest to the stop sign + car_under_stop.sort(key=lambda car: car.vertical_center - stop_sign_patch.vertical_center) + # Remember: return the car + return car_under_stop[0] + + +# Given a list of images: Is there either a standing man that is holding a cell phone or a sitting man that is holding a cell phone? +def execute_command(image_list) -> str: + for image in image_list: + image = ImagePatch(image) + man_patches = image.find("man") + for man in man_patches: + holding_cell_phone = man.simple_query("Is this man holding a cell phone?") + if holding_cell_phone == "yes": + if man.simple_query("Is this man sitting?") == "yes": + return 'yes' + if man.simple_query("Is this man standing?") == "yes": + return 'yes' + return 'no' + + +# Given a list of images: How many people are running while looking at their cell phone? +def execute_command(image) -> str: + image_patch = ImagePatch(image) + people_patches = image_patch.find("person") + # Question assumes only one person patch + if len(people_patches) == 0: + # If no people are found, query the image directly + return image_patch.simple_query("How many people are running while looking at their cell phone?") + people_count = 0 + for person_patch in people_patches: + # Verify two conditions: (1) running (2) looking at cell phone + if person_patch.simple_query("Is the person running?") == "yes": + if person_patch.simple_query("Is the person looking at cell phone?") == "yes": + people_count += 1 + return str(people_count) + + +# Given a list of images: Does the car that is on a highway and the car that is on a street have the same color? +def execute_command(image_list) -> str: + color_1 = None + color_2 = None + for image in image_list: + image = ImagePatch(image) + car_patches = image.find("car") + for car_patch in car_patches: + if car_patch.simple_query("Is the car on the highway?") == "yes": + color_1 = car_patch.simple_query("What is the color of the car?") + elif car_patch.simple_query("Is the car on a street?") == "yes": + color_2 = car_patch.simple_query("What is the color of the car?") + return bool_to_yesno(color_1 == color_2) + + +# Given a list of images: Is the statement true? There are 3 magazine that are on table. +def execute_command(image_list) -> str: + count = 0 + for image in image_list: + image = ImagePatch(image) + magazine_patches = image.find("magazine") + for magazine_patch in magazine_patches: + on_table = magazine_patch.simple_query("Is the magazine on a table?") + if on_table == "yes": + count += 1 + return bool_to_yesno(count == 3) + + +# INSERT_QUERY_HERE \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..62128df953e952220505e6a5ba67b310c03ac825 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,258 @@ +absl-py +accelerate +aiohttp +aiosignal +annotated-types +antlr4-python3-runtime +anyio +appdirs +argon2-cffi +argon2-cffi-bindings +arrow +asttokens +astunparse +async-lru +async-timeout +attrs +Babel +backcall +backoff +beautifulsoup4 +bitsandbytes==0.39.0 +bleach +cachetools +certifi +cffi +charset-normalizer +cityscapesScripts +click +cloudpickle +cmake +coloredlogs +comm +contourpy +cycler +datasets +debugpy +decorator +decord +# deepspeed +defusedxml +packaging +dill +diskcache +distro +diversipy +docker-pycreds +docopt +einops +exceptiongroup +executing +fastapi +fastjsonschema +filelock +fire +flash-attn +flatbuffers +fonttools +fqdn +frozenlist +fsspec +ftfy +gast +gitdb +GitPython +google-auth +google-auth-oauthlib +google-pasta +grpcio +h11 +h5py +hjson +httpcore +httptools +httpx +huggingface-hub +humanfriendly +idna +ImageHash +inflect +interegular +ipykernel +ipython +ipywidgets +isoduration +jedi +Jinja2 +joblib +json5 +jsonpointer +jsonschema +jsonschema-specifications +jupyter +jupyter-console +jupyter-events +jupyter-lsp +jupyter_client +jupyter_core +jupyter_server +jupyter_server_terminals +jupyterlab +jupyterlab_pygments +jupyterlab_server +jupyterlab_widgets +keras +kiwisolver +kornia +lark +libclang +llvmlite +Markdown +markdown-it-py +MarkupSafe +matplotlib +matplotlib-inline +mdurl +mistune +more-itertools +mpmath +msgpack +multidict +multiprocess +nbclient +nbconvert +nbformat +nest-asyncio +networkx +ninja +nltk +notebook +notebook_shim +num2words +numba +numpy==1.23.5 +oauthlib +omegaconf +openai +opencv-python-headless +opt-einsum +outlines +overrides +packaging +pandas +pandocfilters +parso +pathtools +peft +pexpect +pickleshare +Pillow +platformdirs +prettytable +progressbar +prometheus_client +prompt-toolkit +protobuf +psutil +ptyprocess +pure-eval +py-cpuinfo +pyarrow +pyarrow-hotfix +pyasn1 +pyasn1-modules +pybind11 +pycocotools +pycparser +pydantic +pydantic_core +Pygments +pynvml +pyparsing +pyquaternion +PySnooper +python-dateutil +python-dotenv +python-json-logger +pytz +PyWavelets +PyYAML +pyzmq +qd +qtconsole +QtPy +ray +referencing +regex +requests +requests-oauthlib +responses +rfc3339-validator +rfc3986-validator +rich +rpds-py +rsa +safetensors +scikit-learn +scipy +seaborn +Send2Trash +sentencepiece +sentry-sdk +setproctitle +six +smmap +sniffio +soupsieve +stack-data +starlette +sympy +# tensorboard +# tensorboard-data-server +# tensorboard-plugin-wit +# tensorboardX +# tensorflow +# tensorflow-estimator +# tensorflow-io-gcs-filesystem +termcolor +terminado +threadpoolctl +tiktoken +timm +tinycss2 +tokenizers +tomli +# torch +# torchaudio +# torchvision +tornado +tqdm +traitlets +transformers==4.42.4 +triton +typeguard +types-python-dateutil +typing +typing_extensions +uri-template +urllib3 +uvicorn +uvloop +# vllm +wandb +watchfiles +wcwidth +webcolors +webencodings +websocket-client +websockets +Werkzeug +widgetsnbextension +word2number +wrapt +# xformers +xxhash +yacs +yarl +gradio +# huggingface_hub diff --git a/trace_exec.py b/trace_exec.py new file mode 100644 index 0000000000000000000000000000000000000000..f055f398f738c15b93667e71ffbc905db32dd0d1 --- /dev/null +++ b/trace_exec.py @@ -0,0 +1,143 @@ +import ast +import importlib +import io +import os +import re +import string +import time +from functools import partial +from typing import List + +import pysnooper + +FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:" +EXEC_FUNCTION_HEAD = 'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,' \ + ' llm_query, bool_to_yesno, distance, best_image_match):' + + +class CompileTimeError: + pass + + +class ProgramRuntimeError: + pass + + +def process_trace(text, function_head, execution_function_head): + def remove_indent(lines): + n_space = 0 + for i, c in enumerate(lines[0]): + if c == ' ': + n_space += 1 + else: + break + return [line[n_space:] if line[0] == ' ' else line for line in lines] + + def remove_pre_context(lines: List[str]): # lol, just a random use of List + for i in range(len(lines) - 1, -1, -1): + line = lines[i] + if execution_function_head in line: + # assert "call" in line # TODO: further double-check? + content = [line.replace(execution_function_head, function_head)] + lines[i + 1:] + if line[0] == ' ': + return remove_indent(content) + else: + return content + return [] + + def remove_post_context(lines): + for i, line in enumerate(lines): + if line.startswith("Source path:") and line.endswith(__file__): + return lines[:i] + elif line.startswith("Elapsed time"): + return lines[:i] + return lines + + def remove_timestamp(lines): + ret = [] + for line in lines: + if len(line) > 0 and line[0] in string.digits: + line = line[16:] # remove timestamp + ret.append(line) + return ret + + def remove_tensor(line): + return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line) + + lines = text.splitlines() + lines = remove_pre_context(lines) + lines = remove_post_context(lines) + lines = remove_timestamp(lines) + lines = [remove_tensor(line) for line in lines] + + return '\n'.join(lines) + + +cnt = 0 + + +def run_program_with_trace(code, image, input_type_, output_type_): + from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno + + function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_) + execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_) + + code = str(code) + if code.startswith("\ndef"): + code = code[1:] # TODO: just a temporary fix + + if code.startswith('def'): + if code.startswith(function_head): + code = code.replace(function_head, '') + else: + print("--- Code with invalid format\n") + print(code) + code = execution_function_head + code + try: + code = ast.unparse(ast.parse(code)) + except: + return None, CompileTimeError(), None + + global cnt + cnt += 1 + name = f'x{cnt}' + with open(f'{name}.py', 'w') as f: + f.write(code) + + for _ in range(20): + try: + x = importlib.import_module(name) + except ModuleNotFoundError: + print("Errrr, import error. Wait a bit while.") + time.sleep(60) # I have no idea why it sometimes fails. Probably file system error + except Exception as e: + print("Import has error:", e) + break + else: + break + + queues = [None, None] + + image_patch_partial = partial(ImagePatch, queues=queues) + video_segment_partial = None + llm_query_partial = partial(llm_query, queues=queues) + + # signal.signal(signal.SIGALRM, handler) # unfortunately doesn't work + # signal.alarm(60 * 20) # timeout = 10min, just in case while True + with io.StringIO() as f: + with pysnooper.snoop(output=f, color=False, depth=2, max_variable_length=1000): + result = None + error = None + try: + result = x.execute_command(image, None, '', image_patch_partial, video_segment_partial, + llm_query_partial, bool_to_yesno, distance, best_image_match) + except: + error = ProgramRuntimeError() + # finally: + # signal.alarm(0) + os.remove(f'{name}.py') + f.seek(0) + traced = f.read(100000) + traced_processed = process_trace(traced, function_head, execution_function_head) + + return result, error, traced_processed diff --git a/useful_lists/possible_options.json b/useful_lists/possible_options.json new file mode 100644 index 0000000000000000000000000000000000000000..92c2b5e3c4f7e8cc30bb1f7881a216af9bb72df3 --- /dev/null +++ b/useful_lists/possible_options.json @@ -0,0 +1 @@ +{"colors": ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"], "materials": ["wood", "metal", "plastic", "glass", "fabric", "leather", "paper", "ceramic ", "rubber ", "stone"], "actions": ["flying", "splashing", "tossing", "riding", "standing", "hugging", "hanging", "breaking", "having meeting", "pulling", "decorating", "facing", "preparing", "pouring", "pointing", "winding", "using", "petting", "licking", "carrying", "skiing", "chasing", "wedding", "tying", "crossing", "bending", "feeding", "laughing", "driving", "herding", "skating", "making", "observing", "exiting", "folding", "glowing", "sniffing", "grazing", "helping", "sitting", "running", "towing", "reading", "buying", "posing", "slicing", "entering", "jumping", "waving", "cooking", "catching", "dragging", "playing", "lying", "serving", "biting", "chewing", "boarding", "washing", "selling", "squatting", "pushing", "melting", "walking", "holding", "hitting", "resting", "enclosing", "smiling", "balding", "approaching", "swinging", "watching", "kneeling", "crashing", "staring", "grabbing", "moving", "kissing", "sliding", "floating", "smelling", "spraying", "covering", "rippling", "dangling", "kicking", "throwing", "swimming", "blowing", "opening", "brushing", "smoking", "leading", "adjusting", "talking", "peeling", "eating", "climbing", "guiding", "blooming", "snowboarding", "photographing", "spinning", "leaving", "waiting", "drinking", "burning", "cutting", "following", "crouching", "touching", "surfing", "batting", "cleaning", "wearing", "sleeping", "skateboarding", "shining", "surrounding"], "relations": ["splashing", "looking toward", "sleeping on", "walking across", "pulling", "decorating", "topped with", "picking up", "printed on", "preparing", "displayed on", "resting on", "lying on top of", "licking", "carrying", "chasing", "getting on", "growing on", "driving down", "to the left of", "wrapped in", "observing", "traveling on", "standing against", "covered by", "sniffing", "hidden by", "hanging over", "reading", "jumping off", "underneath", "dragging", "with", "flying over", "taller than", "by", "growing in", "eating in", "on the back of", "shorter than", "drinking from", "skiing on", "running through", "playing at", "hitting", "sitting near", "in front of", "trying to catch", "going into", "parked behind", "drawn on", "going through", "walking by", "on the front of", "covered in", "followed by", "hanging in", "kicking", "throwing", "contain", "opening", "brushing", "sitting under", "behind", "flying above", "running on", "flying in", "standing on", "sitting atop", "lying in", "tied to", "longer than", "at", "in between", "eating", "inside", "floating on", "standing in front of", "larger than", "following", "touching", "wearing", "standing near", "running in", "on the bottom of", "walking through", "looking in", "riding", "jumping over", "coming down", "parked on", "walking up", "facing", "playing on", "sitting on top of", "hang from", "smaller than", "around", "sitting inside", "walking toward", "skating on", "standing beside", "looking down at", "tying", "crossing", "walking in", "between", "on the edge of", "on top of", "sitting next to", "hanging out of", "displayed in", "on", "playing in", "growing along", "playing", "biting", "sitting in front of", "boarding", "standing around", "near", "pushing", "parked by", "sprinkled on", "wading in", "enclosing", "staring at", "on the side of", "swinging", "moving", "kissing", "walking to", "leaning against", "reflected on", "smoking", "parked in", "riding on", "lying next to", "reaching for", "walking along", "parked near", "parked at", "riding in", "photographing", "leaving", "cutting", "eating from", "sitting with", "cleaning", "scattered on", "hanging from", "parked beside", "beside", "surrounding", "chained to", "cooked in", "seen through", "talking on", "hugging", "filled with", "draped over", "parked alongside", "of", "worn around", "coming from", "walking next to", "waiting for", "petting", "hanging on", "next to", "walking past", "jumping on", "perched on", "higher than", "parked next to", "attached to", "typing on", "growing by", "feeding", "growing near", "kept in", "standing behind", "sitting on", "standing under", "herding", "making", "walking with", "mounted to", "exiting", "running across", "sitting at", "walking on", "helping", "looking over", "sitting in", "pushed by", "falling off", "decorated by", "buying", "cooking", "catching", "serving", "reflecting in", "chewing", "washing", "traveling down", "looking into", "stuck on", "selling", "looking through", "working in", "holding", "hanging above", "standing by", "parked along", "grabbing", "above", "smelling", "covering", "reflected in", "going down", "walking near", "hung on", "growing next to", "growing from", "lying inside", "leading", "looking out", "sleeping in", "served on", "looking at", "stuck in", "jumping in", "pulled by", "connected to", "growing behind", "walking towards", "climbing", "bigger than", "holding onto", "mounted on", "full of", "sitting behind", "flying", "tossing", "driving on", "mixed with", "pouring", "swimming in", "using", "skiing down", "coming out of", "walking behind", "skiing in", "driving", "grazing on", "parked in front of", "playing with", "blowing out", "walking down", "below", "towing", "slicing", "entering", "sitting beside", "balancing on", "sitting by", "surrounded by", "covered with", "talking to", "wrapped around", "leaning over", "shining through", "smiling at", "working on", "standing next to", "tied around", "approaching", "sitting around", "watching", "beneath", "floating in", "piled on", "flying through", "on the other side of", "stacked on", "hanging off", "eating at", "under", "close to", "about to hit", "adjusting", "standing at", "posing with", "painted on", "decorated with", "guiding", "worn on", "standing in", "in", "drinking", "sewn on", "grazing in", "plugged into", "pointing at", "lying on", "leaning on", "standing on top of", "walking into", "to the right of"], "locations": ["outdoors", "indoors"], "weathers": ["stormy", "overcast", "sunny", "cloudy", "foggy", "partly cloudy", "dark", "rainy", "cloudless", "clear"], "attributes": ["splashing", "brick", "tied", "overcast", "sliced", "massive", "oversized", "comfortable", "military", "shaggy", "large", "hardwood", "rounded", "modern", "upholstered", "cracked", "rimmed", "dry", "fake", "commercial", "wild", "parked", "decorative", "rubber", "male", "indoors", "upside down", "wooded", "speckled", "female", "power", "soft", "shadowed", "crumbled", "disposable", "scarce", "recessed", "wii", "beautiful", "warm", "hairy", "crowded", "light blue", "seasoned", "floppy", "feathered", "halved", "reading", "foreign", "padded", "posing", "jumping", "jagged", "tasty", "sprinkled", "sleepy", "sealed", "sad", "short", "full", "dirty", "spiral", "sweet", "printed", "curved", "cardboard", "narrow", "squatting", "melted", "used", "pine", "copper", "melting", "walking", "lined", "roman", "hitting", "lighted", "strawberry", "railroad", "resting", "rippled", "textured", "frozen", "yellow", "diamond", "huge", "barren", "computer", "crashing", "connected", "bathroom", "granite", "shaved", "fat", "spraying", "blond", "flat", "dangling", "baked", "curly", "greasy", "fuzzy", "hazy", "support", "manicured", "swimming", "sleeveless", "unripe", "blowing", "arched", "scrambled", "snow", "iced", "peeling", "snowy", "blue", "unpeeled", "steamed", "kitchen", "scattered", "banana", "curious", "eating", "damaged", "suspended", "ivory", "closed", "blooming", "purple", "winter", "old", "creamy", "tail", "performing trick", "illuminated", "nike", "waiting", "horizontal", "down", "batting", "roasted", "industrial", "delicious", "browned", "abundant", "bamboo", "cloudy", "smooth", "light", "skateboarding", "wool", "framed", "chalk", "fine", "rolled", "navy", "potted", "riding", "standing", "water", "denim", "brass", "pastel", "blank", "little", "filled", "real", "patchy", "broken", "breaking", "grated", "unpaved", "long", "rough", "dark blue", "sparse", "steep", "breakable", "cloudless", "floral", "chain-link", "bending", "bare", "cloth", "ocean", "barbed", "oblong", "wide", "office", "tangled", "barefoot", "baseball", "diced", "ruffled", "fluffy", "exterior", "paper", "raw", "deciduous", "crooked", "glowing", "rectangular", "rainbow colored", "gray", "mesh", "vibrant", "on", "striped", "tiny", "thick", "tilted", "rocky", "vintage", "crouched", "light colored", "cordless", "inflatable", "angled", "turned", "young", "playing", "soap", "capital", "wrist", "cotton", "cut", "toilet", "half full", "adult", "dusty", "white", "short sleeved", "raised", "dull", "baby", "shingled", "chipped", "funny", "mowed", "woven", "polar", "braided", "swinging", "public", "beaded", "heavy", "powerful", "stone", "urban", "carved", "clay", "coarse", "shallow", "wired", "caucasian", "floating", "octagonal", "paneled", "oriental", "docked", "grilled", "rippling", "groomed", "khaki", "opaque", "glass", "unhappy", "ski", "furry", "fire", "rotten", "cream colored", "clean", "talking", "displayed", "black and white", "eaten", "glossy", "skinny", "adidas", "packed", "loose", "peeled", "muddy", "straw", "black", "gas", "plain", "abandoned", "choppy", "small", "traffic", "misty", "overgrown", "telephone", "fenced", "angry", "chocolate", "fancy", "wispy", "knotted", "surfing", "sloped", "aluminum", "giant", "faded", "slanted", "square", "paved", "pepper", "gloomy", "carpeted", "shining", "crispy", "vertical", "sculpted", "inflated", "leafy", "packaged", "ceramic", "sheer", "vast", "hanging", "rusty", "orange", "tiled", "triangular", "having meeting", "electric", "fried", "fallen", "asphalt", "light brown", "pointing", "winding", "soda", "cylindrical", "healthy", "empty", "coffee", "still", "chinese", "old fashioned", "busy", "shirtless", "bushy", "cobblestone", "dark brown", "skiing", "wedding", "digital", "artificial", "high", "simple", "checkered", "strong", "chopped", "laughing", "antique", "translucent", "rugged", "evergreen", "gold", "pulled back", "crossed", "partly cloudy", "typical", "upper", "attached", "grazing", "sitting", "running", "park", "homemade", "quilted", "mounted", "tropical", "spread", "chrome", "cooking", "rock", "hollow", "wavy", "directional", "dried", "ugly", "immature", "marble", "professional", "muscular", "apple", "open", "bunched", "cushioned", "gloved", "bronze", "clear", "staring", "pointy", "iron", "wire", "chubby", "messy", "deep", "fresh", "miniature", "safety", "wine", "shredded", "leather", "uncooked", "christmas", "piled", "birthday", "soccer", "lush", "up", "torn", "bent", "patched", "rustic", "elevated", "assorted", "american", "dark colored", "edged", "weathered", "street", "stainless steel", "stuffed", "snowboarding", "shaded", "teal", "garbage", "tall", "dense", "outdoor", "protective", "stained", "flowered", "sleeping", "shut", "uneven", "elderly", "powdered", "silk", "tennis", "puffy", "flying", "jeans", "mature", "dotted", "leafless", "gravel", "silver", "transparent", "frosted", "neat", "drawn", "regular", "steel", "beer", "unhealthy", "mixed", "tan", "maroon", "portable", "toy", "corded", "overhead", "unlit", "polished", "analog", "looking up", "ornamental", "thin", "complete", "brunette", "hard", "discolored", "tomato", "cooked", "alert", "folded", "porcelain", "perched", "foggy", "designed", "formal", "reflected", "rainy", "outstretched", "glazed", "trimmed", "sturdy", "calm", "wireless", "driving", "painted", "long sleeved", "unoccupied", "skating", "athletic", "covered", "dark", "french", "vacant", "wicker", "bright", "folding", "spiky", "boiled", "written", "plastic", "ridged", "pale", "curvy", "red", "waving", "trash", "foamy", "metal", "lying", "styrofoam", "pizza", "reflective", "concrete", "twisted", "sunlit", "wrinkled", "domed", "polo", "decorated", "fluorescent", "electronic", "pretty", "lower", "incomplete", "worn", "ripe", "tin", "off", "sandy", "murky", "ornate", "toasted", "grouped", "shaped", "asian", "tinted", "cluttered", "smiling", "uncomfortable", "balding", "blurry", "wrinkly", "kneeling", "crumpled", "bald", "lit", "happy", "sharp", "forested", "sliding", "lace", "plush", "beige", "pink", "stormy", "patterned", "elongated", "wet", "palm", "crisp", "green", "vinyl", "shiny", "straight", "neon", "intricate", "tight", "brown", "abstract", "new", "collared", "double decker", "cheese", "calico", "colorful", "tabby", "made", "handmade", "irregular", "burnt", "stacked", "low", "vanilla", "looking down", "juicy", "license", "spinning", "burning", "drinking", "crouching", "clumped", "wood", "sunny", "plaid", "crystal", "wrapped", "outdoors", "curled", "crusty", "round", "knit", "grassy"], "objects": ["antenna", "store", "food truck", "room", "ornament", "coffee shop", "bags", "statue", "jars", "cup", "baskets", "dining table", "restaurant", "goat", "swan", "vendor", "stadium", "arms", "fork", "ball", "radio", "rabbit", "stove", "rooftop", "horns", "wine glass", "sofa", "robot", "racket", "pork", "ostriches", "hotdog bun", "clocks", "tree leaves", "sauce", "pretzels", "end table", "shopping bag", "waitress", "computer mouse", "sunglasses", "shore", "paw", "mannequins", "yard", "soap dish", "vase", "cinnamon roll", "painting", "buffet", "bird cage", "oreo", "cages", "heart", "food trucks", "turkey", "railroad", "vines", "cappuccino", "cookie dough", "roast beef", "girls", "pancakes", "spectator", "people", "sign post", "drawers", "medicine cabinet", "trailer", "farmer", "chocolate chips", "snow", "menus", "pumpkins", "kitchen", "island", "sprinkles", "sheep", "fries", "backpack", "minute hand", "necklaces", "bar stool", "cyclist", "onion rings", "elephants", "ornaments", "meatballs", "roadside", "flamingo", "supermarket", "cranberry", "salmon", "fan", "packet", "rhinos", "gas pump", "elevator", "cage", "jar", "step", "cord", "leaves", "sugar packets", "balcony", "sharks", "alligator", "ipod", "urinal", "daughter", "lego", "dinner", "pickles", "air conditioner", "fudge", "clothes", "cell phone", "alien", "office", "mozzarella", "toilet tank", "seat", "leggings", "shopper", "buildings", "cauliflower", "vases", "visitors", "bagel", "action figure", "soap", "omelette", "canopy", "wrist", "rubber duck", "officer", "suitcases", "seagulls", "knife block", "ladder", "honey", "beads", "vehicles", "closet", "carriage", "balloon", "hedge", "aircraft", "toolbox", "tea kettle", "collar", "calf", "donkey", "crate", "trash can", "ring", "milk carton", "sweet potato", "suit", "airplanes", "waiter", "waste basket", "squirrel", "feathers", "potatoes", "trays", "desk", "traffic light", "wallpaper", "lilies", "hand", "grinder", "utensil holder", "sweet potatoes", "chinese food", "mannequin", "steam", "dome", "menu", "hats", "pear", "toothpaste", "television", "roll", "mustard bottle", "tea pot", "plain", "cereal box", "cat food", "doll", "mattresses", "snow pants", "turbine", "merchandise", "wall", "whisk", "egg yolk", "taxis", "cream cheese", "bakery", "canisters", "riding boots", "device", "bat", "peanut butter", "chili", "sweatshirt", "alcohol", "remote controls", "cable", "pastry", "snakes", "spots", "comb", "ravioli", "lamp shade", "pliers", "buns", "swans", "animals", "town", "palace", "bear", "jockey", "water glass", "toothbrush", "batter", "windows", "vest", "cockpit", "mushroom", "mailbox", "heater", "muffin", "toilet brush", "ham", "avocados", "park", "fish", "daisy", "doorway", "scrub brush", "beach", "skater", "avocado", "countertop", "outlet", "citrus", "tourists", "bird house", "pizza pie", "seal", "vitamins", "wetsuit", "blood", "decorations", "toilet seat", "hook", "coin", "apartment building", "apple", "goose", "trunk", "bracelets", "bone", "toys", "workers", "parachute", "whipped cream", "castles", "cow", "kimono", "mayonnaise", "pizza shop", "pandas", "bedding", "pumpkin", "car", "fire trucks", "elephant", "ear buds", "dessert", "cemetery", "wardrobe", "post", "olive", "juice box", "peacock", "van", "sticker", "computer monitor", "foot", "coffee beans", "fur", "briefcase", "napkins", "sack", "mickey mouse", "utensil", "icing", "coffee bean", "hill", "horse hoof", "suitcase", "salt", "counter", "street", "cucumbers", "blossom", "bricks", "bacon", "outfits", "finger", "girl", "sword", "surfboards", "candy", "head", "beet", "price tag", "fire extinguisher", "apple logo", "stickers", "scaffolding", "stuffed animals", "pie", "stop sign", "bandana", "stick", "telephone pole", "elbow pad", "spice", "giraffes", "vacuum", "street light", "dragons", "doors", "bagels", "waterfall", "barrier", "wagon", "patio", "coffee pot", "anchovies", "tennis balls", "burger", "home plate", "balls", "potato", "carrots", "rocks", "characters", "thumb", "hamburgers", "pizza box", "cars", "parrot", "goggles", "propeller", "match", "power lines", "diaper", "pine trees", "peppers", "grass", "panda bears", "eagle", "mashed potatoes", "trash", "knives", "almonds", "cooking oil", "kitten", "legs", "wheelchair", "cranberries", "towel dispenser", "boxes", "nutella", "bread", "stuffed dog", "helmet", "cameras", "scooter", "drapes", "dream catcher", "microwave oven", "onions", "costume", "sock", "soft drink", "train", "flatbread", "tables", "swamp", "kiwi", "toppings", "coffee table", "beans", "teddy bears", "boar", "aquarium", "fans", "bed", "egg", "melon", "outfit", "cheese", "cooking pot", "christmas light", "brush", "soldiers", "baseball mitt", "stew", "candle holder", "hurdle", "paintings", "chains", "whale", "bottles", "ducks", "banana bunch", "benches", "dog", "smoke", "wristwatch", "snoopy", "skyscrapers", "hilltop", "train car", "zucchini", "tiles", "crowd", "oak tree", "lady", "pizza slices", "champagne", "sidewalk", "child", "water bottles", "trunks", "hand soap", "bush", "hour hand", "banana bunches", "customer", "lighthouse", "dishwasher", "hangar", "paper towel", "weeds", "pavement", "dip", "bread loaf", "dolls", "hat", "knee pad", "sailboats", "crates", "ropes", "hose", "book", "ice cube", "lily", "pig", "soap dispenser", "cart", "pockets", "bird", "blind", "controllers", "bicycle", "canister", "blanket", "basket", "skillet", "kitchen towel", "harbor", "mixer", "dirt", "roadway", "pomegranate", "monitors", "tape", "hot dog", "pasta", "mustache", "bicycles", "glasses", "pens", "grapefruit", "vegetable", "garden", "kiwis", "pouch", "pitcher", "burrito", "baseball bats", "dress shirt", "lettuce", "helmets", "hippo", "comforter", "branch", "cupcake", "baked good", "lips", "shoe laces", "paper container", "light switch", "artichokes", "spider", "suits", "sandals", "bartender", "tissue box", "skin", "banana", "lobby", "keyboard", "water bottle", "cat", "calculator", "wheel", "steps", "bandage", "log", "wool", "faucet", "spoon", "pocket", "donut", "sneakers", "football", "farmers", "seaweed", "pasta salad", "athletic shoe", "hospital", "potato chips", "dvd player", "baseball players", "pepper shaker", "mall", "oranges", "pretzel", "figurine", "nut", "ocean", "rackets", "game", "sandwich", "cowboy hat", "baseball", "curtains", "sky", "antelope", "washing machine", "paper", "eggplant", "snack", "shirt", "wii controllers", "tag", "tires", "bears", "artichoke", "cabin", "ground", "sporting equipment", "wolves", "stores", "toothbrushes", "ice", "chef hat", "napkin", "stroller", "papers", "face", "ginger", "visitor", "cutting board", "men", "pocket watch", "dry-erase board", "bananas", "clouds", "truck", "vending machine", "mountain peak", "wine bottles", "figurines", "lab coat", "zebras", "door frame", "stone", "paint brush", "cables", "hammer", "tunnel", "baker", "shaving cream", "ottoman", "tray", "dragon", "antelopes", "goal", "meat", "flag", "pen", "fire", "pedestrian", "noodles", "dough", "casserole", "pita", "bugs", "heel", "picnic", "antennas", "towels", "chandelier", "straw", "hibiscus", "giraffe", "office chair", "chairs", "tags", "sheets", "tablecloth", "broccoli", "children", "chocolate", "taxi", "snow boots", "olive oil", "headband", "earring", "undershirt", "sponge", "feta cheese", "pudding", "powder", "marina", "egg roll", "stuffed bears", "performer", "tissues", "father", "cathedral", "orange", "dresser", "wings", "cereal", "umbrellas", "herb", "ovens", "drawer", "arrow", "coffee", "orchids", "fishing pole", "tangerine", "bull", "berries", "bubble", "forest", "mound", "cracker", "lipstick", "tissue", "egg white", "liquor", "gadget", "mug", "lamb", "hotel room", "poodle", "laptops", "baseball bat", "grapes", "vehicle", "snail", "wine bottle", "arm", "watch", "village", "cinnamon", "lemon", "table lamp", "goats", "instrument", "snowboards", "ostrich", "egg shell", "drink", "kites", "cherries", "mud", "player", "gift", "masks", "paper towels", "cane", "trains", "spear", "cash register", "ramekin", "paint", "sandwiches", "grill", "figure", "bedroom", "vinegar", "pan", "stairs", "stuffed animal", "snacks", "neck", "butter knife", "cupboards", "light bulbs", "street sign", "spices", "surfer", "flour", "bowl", "keypad", "sandal", "cooker", "potato salad", "dog food", "baking pan", "ladle", "branches", "package", "macaroni", "platter", "dugout", "projector", "pea", "wig", "castle", "bottle cap", "stir fry", "camel", "pizza cutter", "beer", "pillars", "marker", "walls", "xbox controller", "mixing bowl", "salad", "pastries", "storage box", "biscuit", "sconce", "tomato", "cd", "lobster", "apples", "spray bottle", "peas", "wing", "street lights", "mangoes", "pears", "entertainment center", "apron", "engineer", "dressing", "guitar", "carrot", "spinach", "pecan", "cherry", "can opener", "mexican food", "food container", "groceries", "fruit", "zebra", "houses", "student", "field", "stage", "umbrella", "radiator", "soda can", "fisherman", "cigarettes", "monster", "character", "headphones", "salt shaker", "tuna", "swimsuit", "kiosk", "shampoo bottle", "cell phones", "lunch box", "penguins", "window", "temple", "skateboard", "oatmeal", "flowers", "desk lamp", "piano", "broom", "speaker", "container", "onion", "worker", "dinosaur", "croissant", "soccer ball", "runway", "catcher", "mother", "dryer", "parking meter", "light fixture", "trucks", "coconut", "bucket", "balloons", "mouth", "uniforms", "console", "living room", "apartment", "headboard", "egg carton", "nose", "video game", "lock", "guests", "guacamole", "shirts", "bracelet", "zoo", "coffee mug", "twigs", "angry bird", "bikes", "frame", "burner", "platform", "knee pads", "christmas lights", "tiger", "bouquet", "money", "dresses", "milk", "utensils", "touchpad", "wires", "monitor", "flames", "cakes", "dumpster", "blenders", "teeth", "school bus", "cafeteria", "beach umbrella", "manhole cover", "photographer", "parking sign", "armchair", "tongs", "mustard", "gorilla", "label", "bread box", "bus", "coke", "minivan", "walnut", "soccer player", "pizza pan", "pasture", "computers", "curtain", "traffic sign", "luggage cart", "shield", "scarf", "hair", "peaches", "snow shoes", "penguin", "cotton candy", "parking lot", "shelter", "accessory", "cookie jar", "garage door", "pedestrians", "pancake", "plate", "chalkboard", "tools", "herd", "placemat", "candies", "coffee maker", "skier", "strawberry", "twig", "tongue", "lambs", "computer", "tree", "cups", "rice cooker", "cooler", "bathroom", "mask", "rifle", "rhino", "bunny", "tree branch", "shoe", "blankets", "strawberries", "owls", "tire", "toilet lid", "pizza boxes", "asparagus", "oven", "cows", "dogs", "bus stop", "boat", "son", "carts", "ketchup bottle", "shark", "tail", "mane", "star", "wildflower", "handbag", "cans", "side table", "train tracks", "train station", "shower door", "donkeys", "cigarette", "tool", "display", "road", "couple", "cats", "armor", "nightstand", "cactus", "bookshelf", "shop", "passengers", "buses", "soda cans", "granola", "lid", "toy car", "spray can", "books", "pepperoni", "vegetables", "dvd players", "produce", "snow flakes", "chess piece", "steering wheel", "oil", "athlete", "sausages", "beds", "poster", "gifts", "lawn", "suv", "swimming pool", "sea", "powdered sugar", "floor lamp", "parrots", "highway", "raincoat", "palm trees", "cabinets", "skateboards", "remote control", "toaster", "toilet", "orchard", "bison", "mountain side", "boots", "tomatoes", "cheeseburger", "sun", "shoppers", "classroom", "ice-cream cone", "blossoms", "map", "containers", "lime", "beard", "cap", "wolf", "panda bear", "milkshake", "cliff", "plier", "mice", "guy", "chopstick", "jumpsuit", "syrup", "pizzas", "glass", "frog", "ski", "paddle", "coats", "fence", "skirt", "eagles", "garage", "woman", "pipe", "nest", "steak", "paws", "lemonade", "skate park", "raspberry", "clock tower", "knife", "loaf", "magnet", "pistachio", "printer", "bag", "silverware", "audience", "cheesecake", "bikini", "gate", "wallet", "sunflowers", "wii game", "blueberries", "blueberry", "logo", "hotel", "rope", "pond", "panda", "underwear", "controller", "bug", "american flag", "sticks", "deck", "restroom", "dvds", "shopping cart", "pipes", "pineapples", "towel", "basil", "motorcycle", "wedding", "lions", "fire hydrant", "nuts", "bench", "sheet", "tent", "salad dressing", "sand", "game controller", "jewelry", "hot dogs", "swimmer", "surfboard", "pigeon", "skyscraper", "chicken", "horse", "umpire", "entrance", "bushes", "sail", "moose", "homes", "plant", "logs", "appliance", "name tag", "hard drive", "kite", "mint", "toast", "cloths", "birds", "bee", "lamps", "power line", "door", "walnuts", "tents", "parsley", "hallway", "airplane", "helicopter", "butter", "pigs", "cds", "ice maker", "dinosaurs", "safety jacket", "feeder", "pots", "wire", "puddle", "newspaper", "pillowcase", "factory", "rain", "meats", "beer bottle", "chicken breast", "bedspread", "bun", "snow shoe", "cords", "tower", "ice cubes", "crane", "hand dryer", "blender", "cupboard", "flip flops", "dress", "leg", "tattoos", "butterflies", "canoe", "porch", "wok", "robe", "dolphins", "clock", "crust", "jackets", "dumplings", "donuts", "fountain", "snake", "moss", "owl", "geese", "magazines", "pencil", "earphones", "soccer balls", "pencils", "shampoo", "flags", "bell", "taco", "pot", "mat", "bridge", "drum", "life preserver", "frosting", "tea", "shrimp", "crown", "lamp", "artwork", "breakfast", "t-shirt", "hillside", "desert", "policeman", "rice", "rug", "earrings", "pans", "chimney", "coach", "games", "sink", "tomato sauce", "pizza oven", "waffle", "tourist", "hummus", "garnish", "hairbrush", "envelope", "feet", "sunflower", "seat belt", "chickens", "light bulb", "sculpture", "pizza", "trumpet", "pantry", "coffee cup", "food", "lavender", "drain", "lion", "blouse", "coconuts", "wristband", "officers", "flip flop", "players", "bleachers", "sticky notes", "soup", "toilet bowl", "tractor", "outlets", "beach chair", "microwave", "toiletries", "life jacket", "phone", "parent", "palm tree", "mountain", "hands", "church", "frisbee", "batteries", "eyes", "cabinet", "cucumber", "meadow", "crackers", "gun", "river", "food processor", "crab", "seeds", "liquid", "beets", "smoke stack", "forks", "peanut", "vanilla", "tractors", "cake", "pillow", "drawings", "picnic tables", "toilet paper", "license plate", "belt", "waist", "berry", "students", "trees", "fruits", "amusement park", "sweater", "refrigerator", "tank top", "soldier", "boot", "spoons", "kettle", "wines", "cigar", "plantains", "cream", "blinds", "rose", "tree trunk", "fireplace", "wii", "bus driver", "purse", "restaurants", "crosswalk", "spatula", "spectators", "airport", "floor", "binder", "caramel", "coleslaw", "magazine", "raisins", "onion ring", "stones", "baked goods", "buoy", "parmesan cheese", "cookies", "ribs", "sign", "entree", "herbs", "couches", "pump", "corn", "toaster oven", "grape", "guys", "station", "candle", "heels", "farm", "salon", "cupcakes", "museum", "saucer", "battery", "tie", "building", "dish soap", "mushrooms", "backyard", "flower pot", "watermelon", "shops", "oven door", "city", "baking sheet", "papaya", "hills", "cone", "shopping center", "attic", "window frame", "bell tower", "ketchup", "pills", "leopard", "desserts", "banana peel", "folding chair", "elmo", "plants", "dishes", "polo shirt", "person", "chain", "moon", "picnic table", "picture frame", "melons", "necklace", "printers", "symbol", "cake pan", "drape", "squash", "bath towel", "gummy bear", "appetizer", "bathtub", "cross", "stump", "computer desk", "soap bottle", "water", "beer can", "popcorn", "stars", "lake", "monkeys", "sour cream", "coffee cups", "shoes", "charger", "packages", "turbines", "passenger", "sushi", "socks", "bar stools", "meal", "motorcycles", "control panel", "orchid", "shower curtain", "eiffel tower", "seagull", "jeep", "thermometer", "card", "theater", "buoys", "octopus", "snowboard", "peach", "snowsuit", "windshield", "skateboarder", "rolling pin", "customers", "auditorium", "tofu", "school", "cheetah", "serving tray", "sugar", "shoe lace", "dock", "tv stand", "stapler", "bike", "eye", "pants", "gentleman", "fire truck", "cotton dessert", "coat", "laptop", "locomotive", "broth", "baby", "ice cream", "appetizers", "fog", "cheese cube", "tortilla", "beetle", "hedges", "puppy", "pilot", "jet", "hair dryer", "televisions", "shelves", "pajamas", "satellite dish", "fence post", "bookshelves", "dish drainer", "ear", "cooking utensils", "cooking utensil", "leaf", "picture", "turtle", "watermelons", "phones", "backpacks", "marshmallow", "roses", "barn", "horses", "camera", "bottle", "gym", "house", "drawing", "cabbage", "croissants", "lizard", "scooters", "beer cans", "raspberries", "brownie", "towers", "olives", "sea foam", "pepper", "market", "router", "toothpicks", "receipt", "bowls", "sausage", "stuffed bear", "glaze", "blackberries", "cookie", "clock hand", "gown", "pillows", "nightstands", "ingredient", "eggs", "monkey", "duck", "vests", "eye glasses", "soda", "topping", "pizza tray", "drinks", "cabinet doors", "ski lift", "serving dish", "courtyard", "cake stand", "graffiti", "animal", "foil", "word", "sugar packet", "gas stove", "yogurt", "brownies", "boy", "employee", "man", "walkway", "air", "number", "intersection", "mattress", "hamburger", "rock", "sailboat", "team", "ears", "candles", "net", "bubbles", "beverages", "jersey", "mirror", "family", "chef", "teddy bear", "bomb", "ceiling light", "speakers", "grater", "wii controller", "glove", "shorts", "gourd", "magnets", "electric toothbrush", "soap dispensers", "seed", "shower", "notebook", "wine", "face mask", "letter", "biscuits", "hay", "garment", "celery", "french toast", "hippos", "almond", "smoothie", "uniform", "mugs", "garland", "can", "ambulance", "lunch", "ladles", "women", "wine glasses", "alarm clock", "mouse pad", "soda bottle", "pigeons", "seafood", "coral", "butterfly", "jeans", "screen", "mountains", "gravel", "dolphin", "gloves", "life jackets", "deer", "microphone", "boys", "couch", "shelf", "cookbook", "toy", "parachutes", "pub", "boulders", "cowboy", "traffic lights", "vine", "ship", "cloud", "gravy", "blazer", "library", "waffles", "carpet", "flower", "pole", "keyboards", "path", "pine tree", "letters", "cards", "toddler", "tablet", "wildflowers", "kangaroo", "hearts", "cabinet door", "office supplies", "driver", "lounge", "pikachu", "video camera", "fingers", "horn", "policemen", "video games", "raisin", "jacket", "lemons", "dish", "plates", "beak", "chair", "machine", "crumbs", "beef", "peanuts", "cones", "vendors", "tennis ball", "terminal", "mirrors", "boats", "napkin dispenser", "juice", "weapon", "fruit stand", "table", "pizza crust", "garlic", "skis", "scissors", "desktop computer", "wheels", "trash bag", "lip", "dragonfly", "words", "biker", "snowboarder", "beverage", "decoration", "staircase", "tree branches", "paper dispenser", "polar bear", "piercing", "bookcase", "gas station", "luggage", "shuttle", "planter", "pineapple", "muffins", "bats", "beer mug", "notepad", "chickpeas", "numbers", "chopsticks", "obstacle", "cafe", "dispenser", "pesto", "box", "mango", "ceiling", "dining room", "pizza slice", "roof", "kittens", "hair clip"]} \ No newline at end of file diff --git a/useful_lists/random_negatives.txt b/useful_lists/random_negatives.txt new file mode 100644 index 0000000000000000000000000000000000000000..ff05b5d799a6d56f01a2f89e10e95891e85862d7 --- /dev/null +++ b/useful_lists/random_negatives.txt @@ -0,0 +1,2737 @@ +beloved +flower +storm +possession +office +poetry +manager +severe +gregarious +influence +melodic +obvious +wise +discussion +awkward +enlightened +flashy +year +scheme +agreement +disloyal +alert +everlasting +spot +border +screw +republic +independence +rain +step +instructive +fix +impeccable +most +canine +mode +witty +sad +taut +whimsical +keep +side +wiggly +chemical +tricky +view +disgusting +clothes +gorgeous +recording +boring +anxiety +efficiency +windy +theme +escape +defiant +vehicle +lead +peak +uncomfortable +even +weird +dead +elated +prudent +audience +needy +external +message +skirt +extent +crack +aromatic +fine +cigarette +unwilling +effective +term +earth +cloud +similar +couple +giddy +energy +sentence +fly +you +yawning +proper +drama +quick-witted +faithful +exchange +zone +quaint +volume +change +hateful +courage +beer +dance +entertainment +resort +fearless +magazine +ancient +winged +doubt +insubstantial +baby +fishing +functional +another +basket +outlandish +stylish +distinct +smart +hair +heartfelt +quintessential +true +formal +penalty +rough +firm +fan +tight +excited +worse +selection +run +grumpy +medium +service +knowledge +dapper +quarrelsome +thin +trade +ball +hell +tattered +third +worthless +frivolous +conventional +maximum +velvety +capital +version +public +reserve +chief +increase +regular +bar +tremendous +blow +childhood +extraneous +solid +nothing +roof +dog +season +suspicious +dishonest +polished +complete +smug +cold +spicy +sound +setting +will +nose +owner +milk +thing +opening +son +art +lively +lady +early +future +jagged +permission +hole +potato +depth +let +fuel +love +ironclad +combination +irritating +chilly +giving +hideous +fail +name +tennis +busy +slide +limp +spring +wilted +worth +impressionable +agency +passenger +wooden +series +pin +apartment +shower +maintenance +lazy +gifted +panic +incomparable +remarkable +bottle +cluttered +voice +mate +visible +drunk +worry +male +bridge +mammoth +film +flowery +perfumed +load +international +flawed +economics +intentional +wonderful +cancel +reception +dinner +shame +expensive +ship +travel +guard +spare +fake +potable +appropriate +ring +gargantuan +lavish +suggestion +pollution +playful +drab +appointment +command +policy +heavy +worldly +internet +copy +webbed +give +raw +breakable +somewhere +jury +equipment +jealous +master +course +yummy +gleeful +flippant +due +bossy +manufacturer +outside +unusual +bedroom +shock +salt +miss +egg +affectionate +cow +parent +last +partner +plan +direct +check +bonus +immaterial +charge +reflecting +mine +noise +importance +circular +professor +signal +vivacious +obligation +grouchy +nimble +mistake +broad +frizzy +watch +buttery +well-to-do +cooperative +mood +dense +evergreen +husky +street +opposite +adolescent +television +yearly +level +good +final +lustrous +employ +frayed +addition +advantage +absolute +native +cookie +replacement +raise +appeal +unsung +buyer +that +self-reliant +purpose +remorseful +trainer +handy +red +focused +glistening +north +league +honorable +nippy +lawyer +sleep +sick +dull +resident +author +strength +gigantic +bronze +fact +reputation +inborn +transition +ordinary +throat +king +humiliating +miniature +pretty +idolized +top +security +library +letter +primary +impractical +dream +noisy +bus +event +climate +frosty +cloudy +same +class +union +sample +player +partial +stroke +trip +brain +window +button +skin +money +drive +jittery +gullible +striped +cabinet +extra-large +greedy +dazzling +period +funny +knowledgeable +lawful +cause +fabulous +neglected +icy +possibility +club +cool +resolve +meal +risk +zany +fumbling +outgoing +type +salary +set +arid +content +breakfast +gap +swimming +position +foolhardy +roasted +virtuous +empty +subtle +title +put +turbulent +distance +grass +comfort +economy +glamorous +past +uniform +false +insignificant +fitting +pessimistic +mission +submissive +parched +darling +motor +only +school +plush +consequence +harmless +memorable +pristine +pair +parallel +villainous +grade +refrigerator +card +frame +sex +delectable +sandy +farmer +vicious +degree +harmful +guess +ignorant +costly +understanding +afraid +crash +judicious +laugh +untried +cable +tank +steak +muted +annual +debt +imperturbable +cruel +chart +humble +surround +those +birth +growing +ugly +leg +productive +link +intent +hang +tragic +horrible +far +power +scene +selfish +lie +status +contact +wild +mother +enraged +odd +prize +motionless +far-off +department +hurt +depression +insect +shine +estimate +creative +kind +orange +noxious +scientific +pie +gate +driver +example +steep +recipe +length +draft +tepid +beach +angry +pattern +key +political +concrete +scratchy +strident +bony +studio +improvement +rich +record +wonder +tip +singer +teach +confusion +gross +middle +super +upper +winter +tomorrow +optimistic +wasteful +temporary +emotional +stage +several +bold +man +favorite +release +attractive +realistic +duty +tone +restaurant +plane +history +aware +hard +opportunity +determined +father +noteworthy +rush +analyst +leafy +well-groomed +supportive +brick +daughter +grandmother +head +sizzling +angelic +cylindrical +yellowish +surprised +beneficial +description +deafening +complaint +exit +staff +splendid +flickering +monthly +wine +line +concert +wary +food +verifiable +toe +rent +repulsive +metal +lumbering +bunch +hook +modern +career +thunderous +situation +crowded +nation +wood +fresh +book +staid +vital +claim +hefty +problem +bid +sister +boy +initiative +frigid +inconsequential +mention +profuse +tourist +theory +bath +excitable +scaly +front +monumental +dare +thorough +spray +bare +daring +upstairs +lay +go +two +ill-informed +guitar +resolution +politics +courageous +news +little +able +party +smile +half +scarce +bicycle +menacing +pay +agitated +enchanted +useless +communication +metallic +single +purple +action +ear +slow +abandoned +accurate +excuse +intrepid +blood +family +hearing +multicolored +rotten +sympathetic +monitor +misguided +competition +weak +grateful +script +esteemed +academic +grimy +friend +adventurous +flat +boat +piercing +flawless +following +voluminous +secret +disease +skeletal +pop +unwritten +moist +usual +accomplished +improbable +wing +lighthearted +curvy +switch +guarantee +rule +short-term +grandfather +useful +police +punch +tale +incompatible +mature +section +inspector +meek +member +striking +unselfish +airline +colorful +look +current +hurtful +united +unacceptable +messy +organization +gracious +sinful +child +march +mud +sail +left +peaceful +boiling +location +fault +pleased +map +place +coordinated +recommendation +pure +juvenile +optimal +harmonious +lab +candy +proud +loving +ability +meaning +secretary +mad +snarling +weepy +arrival +struggle +unlawful +complicated +passage +interaction +mixed +pesky +known +regret +kid +jacket +French +information +concern +highway +wheel +memory +buddy +feedback +untidy +abuse +extreme +extroverted +prompt +suspect +hate +mealy +might +scented +doctor +brother +cry +fold +bustling +excitement +number +cash +brown +altruistic +meaty +difficult +log +bitter +sorrowful +life +gummy +amount +paper +work +double +permit +judge +administration +oval +notice +devil +revenue +operation +tart +antique +engineer +note +overjoyed +live +klutzy +every +survey +gripping +plastic +one +outrageous +split +picture +score +high-level +basis +employee +plump +answer +people +nautical +soggy +radiant +horse +famous +distant +solution +aching +colossal +strip +sure-footed +responsible +shift +friendly +thoughtful +chocolate +mixture +glass +candid +mild +creepy +presentation +membership +definitive +honored +curly +warped +fussy +church +software +bug +misty +listen +professional +oblong +goal +descriptive +bag +apple +puny +activity +upbeat +weighty +livid +guy +pound +area +stock +favorable +delicious +understated +linear +pale +sir +massive +independent +vigorous +aggressive +simplistic +tall +cup +superior +homework +tree +land +relation +overcooked +reading +aspect +liquid +extension +challenge +judgment +squiggly +wobbly +tinted +rise +superficial +hand +urban +neighboring +commission +made-up +wealth +control +initial +start +associate +indication +dealer +big-hearted +chest +regal +arm +piece +lesson +possible +student +motherly +sturdy +objective +advice +practice +shopping +joyful +mom +monstrous +crushing +ripe +impassioned +kiss +worker +engineering +jovial +imaginative +sniveling +young +animal +feline +decent +royal +edible +farm +fixed +kick +jumbo +clerk +mind +blue +web +bottom +forever +admired +progress +husband +respect +horror +roll +fall +petty +round +lecture +idle +worn +star +emergency +fight +quarter +business +talk +periodic +low +definition +wealthy +smoke +focus +quizzical +wretched +infatuated +disk +device +figure +function +wry +crafty +svelte +potential +board +speedy +weight +writer +argument +illiterate +measurement +development +bountiful +base +rock +puzzled +cheap +part +drawer +shy +tie +finger +variation +normal +negative +bad +joint +patient +code +force +case +ample +trusting +pace +hotel +display +bubbly +flustered +acceptable +reason +departure +expert +sudden +diligent +cultivated +belated +act +revolution +whirlwind +rectangular +promise +sardonic +disfigured +dark +delayed +dimpled +unconscious +sing +animated +lack +gather +light +repair +preparation +biodegradable +turn +idea +impression +corrupt +play +boss +treasured +torn +corner +delivery +stupid +night +trustworthy +tradition +gift +still +adult +deadly +simple +many +girl +till +box +explanation +tongue +yard +personality +dish +neat +brush +poor +alcohol +affair +courteous +month +culture +appearance +piano +complex +peace +project +parking +white +conflict +show +perception +birthday +demanding +wife +illegal +hat +coffee +silly +ringed +promotion +majestic +chain +state +extra-small +match +imperfect +schedule +mobile +prior +west +bit +youthful +natural +question +gold +narrow +violent +blushing +testy +fun +sock +old +dramatic +spectacular +size +energetic +enormous +trouble +shape +feed +alive +hit +frightening +relative +impish +marvelous +rare +shrill +careless +refuse +muscle +balance +untimely +tooth +phrase +enthusiasm +patience +care +commercial +loyal +watery +bright +computer +dangerous +grab +legal +frightened +injury +spiritual +ground +innocent +utilized +tell +distribution +knotty +worst +stark +rice +recent +ecstatic +scary +bread +attempt +towel +earnest +electric +sore +hoarse +cute +idiotic +watchful +exalted +squeaky +slip +close +spite +burn +anguished +sell +traffic +net +instance +lumpy +these +all +apprehensive +miserable +sarcastic +terrific +frozen +previous +writhing +connection +scholarly +drafty +fee +hour +starry +chip +flimsy +elliptical +big +she +door +twist +fancy +consideration +range +water +treat +unruly +immediate +battle +envious +incomplete +kooky +join +package +responsibility +lock +vacation +scratch +new +strike +hasty +nervous +plaintive +informal +exercise +chair +media +clever +trusty +garage +representative +brave +item +inferior +first +wavy +weekly +percentage +artistic +necessary +tough +sense +session +rip +short +precious +peppery +opulent +closed +homely +major +open +mark +ill-fated +push +virus +stand +factor +beginning +south +general +meager +lip +week +spiteful +zesty +unpleasant +unlined +thanks +superb +identical +detailed +day +harsh +building +secondary +flaky +reward +worthy +pizza +user +brilliant +wordy +virtual +adored +priest +impartial +newspaper +tower +housing +old-fashioned +ease +mellow +unsightly +loss +oily +sane +nonstop +flow +director +mountainous +dependent +drawing +black-and-white +body +tan +read +reliable +face +clock +context +acrobatic +chapter +dear +sharp +next +willing +well-off +nutritious +jump +ready +probable +graceful +present +austere +tonight +share +seat +uneven +pushy +car +bake +model +passionate +grown +consist +blaring +foundation +pitiful +failure +jam-packed +blond +wrathful +gaseous +finance +candidate +prime +putrid +calculating +minor +bite-sized +budget +can +account +quick +bone +sector +pleasing +thought +ride +growth +campaign +data +surgery +snoopy +upset +anchored +medical +assured +agile +posh +incident +reckless +hello +trifling +candle +taste +leading +exciting +advance +chicken +sparse +quixotic +assignment +diet +vegetable +bouncy +transportation +vacant +starchy +society +concentrate +row +brief +neck +ask +cell +dig +proposal +technology +grizzled +junior +caring +weather +back +manner +park +dental +teacher +special +questionable +qualified +deposit +faint +perky +win +respectful +lost +which +respond +boyfriend +issue +easy +hope +composed +sink +poised +machine +unfolded +magnificent +dimension +weakness +youth +intelligent +fortunate +stick +bother +thick +stale +disastrous +portly +wall +garden +speech +colorless +hidden +camp +price +incredible +stimulating +teeming +our +cumbersome +stingy +tired +presence +wave +warlike +officer +county +oddball +steel +dizzy +guest +resist +second-hand +outstanding +awful +pack +fortune +god +review +upright +quantity +fill +long +suck +east +dutiful +counter +click +utter +elastic +vapid +burdensome +wait +vigilant +inspection +luck +painting +inexperienced +fruit +bed +silent +inflation +truth +test +faraway +baseball +kindly +quality +unequaled +whole +signature +nerve +shake +unkempt +rosy +wiry +mediocre +feminine +unsteady +register +actual +foot +decision +grave +broken +joke +discipline +lanky +bill +crisp +interview +craft +spirited +immense +hot +country +burly +gear +fond +reference +jaunty +district +adept +remove +attention +introduction +unfit +vast +poet +education +nurse +trained +eye +mess +bewitched +cycle +friendship +grubby +bear +vain +quarterly +equal +mundane +field +order +bruised +eat +long-term +village +management +generous +corny +green +troubled +late +house +association +essay +group +mirror +unwieldy +knife +lone +deserted +smoggy +dopey +garbage +shot +sand +crooked +immaculate +equivalent +ragged +construction +forceful +awareness +hilarious +plate +particular +form +gentle +use +married +freedom +grocery +amazing +unhappy +difficulty +estate +lean +hopeful +role +kitchen +definite +safety +help +movie +stranger +paint +shirt +sale +decimal +article +different +lift +national +negligible +glove +stay +desk +heat +database +story +frilly +exam +gene +punctual +application +prickly +station +infinite +slight +fair +thirsty +bowl +bat +cooked +yesterday +witness +both +feature +scornful +glaring +spread +curve +convert +report +truthful +fruitful +main +firsthand +woozy +cook +payment +acclaimed +helpful +lake +limit +positive +tackle +high +ashamed +automatic +heavenly +resource +shoe +knobby +tap +maybe +usable +assistant +menu +deficient +whereas +shiny +math +few +employment +specific +pull +remote +anything +conclusion +advanced +agonizing +elevator +examination +unique +stupendous +satisfied +infantile +offensive +imagination +client +passion +string +senior +property +dirty +teaching +spiffy +snow +juicy +impure +program +adorable +statement +store +haunting +downright +interesting +hollow +decisive +critical +square +dust +airport +female +reasonable +unwelcome +finish +university +world +pass +track +carry +flamboyant +admirable +layer +queen +profitable +stuff +difference +band +flight +careful +sleepy +condition +unwitting +cut +try +fluid +system +miserly +fuzzy +measly +afternoon +bet +market +customer +considerate +shoulder +soulful +minty +drag +visual +impress +island +awesome +dirt +target +disaster +marriage +quote +home +eminent +reveal +walk +pain +impolite +good-natured +imaginary +repentant +topic +bland +prestigious +digital +investment +criminal +thorny +chubby +attitude +image +while +mouth +forsaken +philosophy +coach +wind +rate +method +original +ladder +sympathy +spell +storage +handmade +milky +make +scrawny +lined +bank +golf +handsome +unrealistic +concept +detail +ambition +frail +smell +powerful +blind +rowdy +nature +talkative +wee +limited +near +phony +frugal +pleasant +popular +telephone +stunning +shadowy +performance +unfortunate +skill +grim +indelible +today +satisfaction +ajar +government +leader +pick +naughty +interest +designer +anger +living +spotless +timely +sweaty +soft +spend +science +tame +Spanish +buzzing +deep +best +psychology +sweet +filthy +shabby +video +oil +outlying +game +worthwhile +dot +some +coat +editor +lovable +block +obese +insidious +wash +mortified +unhealthy +election +contest +thrifty +bumpy +ideal +ultimate +right +hall +pot +beautiful +college +result +avaricious +rural +nasty +spirit +swift +preference +collar +pointless +affect +source +irresponsible +heart +delirious +untrue +athletic +arctic +city +criticism +shimmering +engine +perfect +scared +authentic +individual +mortgage +mix +offer +road +overdue +agent +job +routine +evening +defenseless +disguised +award +stretch +currency +gloomy +direction +distorted +slim +serpentine +delay +sugary +helpless +forthright +it +active +other +euphoric +bowed +blame +age +literature +easy-going +being +ill +fearful +tea +ice +variety +sunny +sophisticated +finished +private +triangular +cross +polite +swing +finding +free +sensitive +failing +feeling +crew +town +time +fluffy +word +wrap +plenty +product +lonely +shameless +stiff +experienced +race +visit +network +used +alarm +effect +elderly +compassionate +loud +mall +practical +song +wear +midnight +negotiation +relief +search +tangible +loan +suburban +train +clumsy +photo +confidence +discount +sky +small +draw +legitimate +environment +design +hire +guide +growling +somber +rotating +gruesome +mushy +collection +minimum +rigid +rope +sour +point +glad +aggravating +accident +obedient +actor +alarming +address +second +alarmed +basic +icky +shop +cheek +cuddly +health +cavernous +queasy +woman +limping +return +giant +temperature +delightful +devoted +unfinished +indolent +bell +venerated +slice +entire +querulous +print +poem +mountain +essential +impact +required +forked +leadership +notable +guilty +muffled +specialist +tax +hide +champion +jumpy +fatal +belt +tiny +oven +noted +dump +real +elaborate +recognition +standard +serve +unimportant +truck +lame +funeral +ticket +death +response +well-lit +cultured +style +sandwich +loathsome +list +quirky +silky +kaleidoscopic +swim +hold +self +wide +landscape +inevitable +profile +tool +vibrant +beyond +uncommon +rundown +pricey +shoot +law +luxurious +region +site +defensive +sugar +modest +welcome +character +request +river +moment +medicine +breast +pertinent +grotesque +table +clean +trivial +loose +thankful +experience +dimwitted +smooth +yellow +human +competent +exotic +divide +minute +instruction +studious +threadbare +numb +elementary +cover +value +exhausted +holiday +priority +agreeable +astonishing +ruin +slippery +reaction +likely +feel +rubbery +official +babyish +radio +cancer +grandiose +authorized +soupy +research +outcome +comfortable +press +charity +cap +wake +speaker +working +produce +cake +quiet +this +humming +worrisome +average +fantastic +crazy +cost +employer +sneaky +intention +total +closet +shoddy +camera +strange +grounded +vengeful +fatherly +weekend +guidance +glossy +branch +mindless +strict +entrance +path +moral +sparkling +glittering +grand +luminous +summer +industry +lunch +carefree +boot +atmosphere +relieved +well-worn +combine +save +leather +rude +rub +sweltering +infamous +amused +important +frank +bathroom +runny +gas +hearty +great +ruddy +effort +exemplary +salty +access +twin +break +likable +concerned +nifty +president +subdued +well-documented +wet +football +mail +music +insecure +pension +ornate +comment +blank +bleak +dry +war +classroom +straight +attack +creamy +valid +soup +strong +profit +choice +genuine +hospital +analysis +wide-eyed +juice +alienated +eager +droopy +way +stomach +common +nobody +itchy +assumption +angle +quit +subject +text +zealous +entry +touch +tasty +mean +dreary +black +feisty +income +idealistic +reach +insurance +kindhearted +brisk +familiar +waterlogged +object +move +file +training +savings +enchanting +trust +stormy +historian +discrete +morning +cheery +debate +zigzag +expression +army +occasion +worried +soil +bend +emotion +dress +weary +comparison +acidic +gain +honest +speed +coast +coarse +equatorial +contract +conscious +person +chemistry +date +principle +sign +emphasis +pastel +revolving +plain +contribution +surprise +vague +company +wish +height +pungent +evidence +glorious +baggy +huge +pitch +opinion +unknown +embarrassed +physics +gleaming +damaged +sport +rusty +fear +anywhere +steal +significance +mouse +trash +room +cheerful +drop +merry +kosher +dismal +space +gray +damp +trim +repeat +keen +shady +woeful +fickle +warning +stained +novel +valuable +any +platform +reflection +tense +hungry +breath +recover +tempting +palatable +option +page +separate +violet +stress +drink +cream +supermarket +unit +fragrant +granular +hospitable +elegant +success +tune +strategy +lovely +full +wrong +tubby +shocking +attentive +safe +paltry +charming +spherical +dad +naive +pen +humongous +fast +unnatural +tour +ethical +dearest +sun +count +impossible +fat +ratio +calendar +aside +team +suit +channel +better +celebration +conversation +pool +procedure +reply +salad +handle +bite +material +dependable +phase +knowing +leave +serene +rest +excellent +catch +inside +anybody +purchase +strain +clear +orderly +robust +writing +dual +structure +iron +study +well-made +demand +invite +happy-go-lucky +slushy +kill +hard-to-find +whopping +phone +sort +organic +if +confused +wan +winding +lucky +deal +local +establishment +unripe +sticky +pointed +tedious +traumatic +rash +occasional +anxious +court +clue +variable +beat +spry +error +pipe +trick +harm +illustrious +wicked +hunt +meat +insistent +personal +task +scale +unlucky +impressive +call +sociable +ad +amusing +physical +edge +chance +stable +celebrated +fish +rapid +buy +constant +host +lopsided +victorious +healthy +unaware +cautious +nocturnal +blissful +bird +hairy +warm +support +substance +murky +uncle +tension +clueless +unused +buoyant +internal +evil +ambitious +advertising +mysterious +bike +golden +slimy +embellished +matter +relationship +clear-cut +profession +floor +post +background +meet +western +cousin +barren +pressure +tender +glum +educated +pride +shocked +puzzling +highlight +jolly +muddy +shameful +habit +community +perspective +showy +substantial +meeting +nail +shelter +jubilant +wedding +tear +skinny +masculine +alternative +silver +protection +knee +happy +serious +damage +attached +marketing +calm +credit +air +category +reality +musty +nice +aged +pleasure +pause +each +bulky +screen +plant +tidy +fire +joyous +championship +rewarding +requirement +sea +language +snappy +overlooked +spotted +dim +end +classic +cat +implement +frequent +ornery +diamond +jaded +population +apt +honey +bench +nutty +assistance +winner +large +shallow +self-assured +powerless +terrible +far-flung +foolish +abroad +process +whispered +stop +carpet +desire +sentimental +vivid +lasting +well-informed +hurry +assist +bogus +committee +doting +benefit +conference +proof +warmhearted +offbeat +document +girlfriend +pink diff --git a/vision_models.py b/vision_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c16e28d16f36ca2c91a3c2be4e08891280648124 --- /dev/null +++ b/vision_models.py @@ -0,0 +1,745 @@ +import abc +import os +import re +import timeit +from typing import Union + +import torch +import torchvision +from PIL import Image +from torch import hub +from torch.nn import functional as F +from torchvision import transforms + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +class BaseModel(abc.ABC): + to_batch = False + seconds_collect_data = 1.5 # Window of seconds to group inputs, if to_batch is True + max_batch_size = 10 # Maximum batch size, if to_batch is True. Maximum allowed by OpenAI + requires_gpu = True + num_gpus = 1 # Number of required GPUs + load_order = 0 # Order in which the model is loaded. Lower is first. By default, models are loaded alphabetically + + def __init__(self, gpu_number): + self.dev = f'cuda:{gpu_number}' if device == 'cuda' else device + + @abc.abstractmethod + def forward(self, *args, **kwargs): + """ + If to_batch is True, every arg and kwarg will be a list of inputs, and the output should be a list of outputs. + The way it is implemented in the background, if inputs with defaults are not specified, they will take the + default value, but still be given as a list to the forward method. + """ + pass + + @classmethod + @abc.abstractmethod + def name(cls) -> str: + """The name of the model has to be given by the subclass""" + pass + + @classmethod + def list_processes(cls): + """ + A single model can be run in multiple processes, for example if there are different tasks to be done with it. + If multiple processes are used, override this method to return a list of strings. + Remember the @classmethod decorator. + If we specify a list of processes, the self.forward() method has to have a "process_name" parameter that gets + automatically passed in. + See GPT3Model for an example. + """ + return [cls.name] + + +# ------------------------------ Specific models ---------------------------- # + + +class ObjectDetector(BaseModel): + name = 'object_detector' + + def __init__(self, gpu_number=0): + super().__init__(gpu_number) + + detection_model = hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).to(self.dev) + detection_model.eval() + + self.detection_model = detection_model + + @torch.no_grad() + def forward(self, image: torch.Tensor): + """get_object_detection_bboxes""" + input_batch = image.to(self.dev).unsqueeze(0) # create a mini-batch as expected by the model + detections = self.detection_model(input_batch) + p = detections['pred_boxes'] + p = torch.stack([p[..., 0], 1 - p[..., 3], p[..., 2], 1 - p[..., 1]], -1) # [left, lower, right, upper] + detections['pred_boxes'] = p + return detections + + +class DepthEstimationModel(BaseModel): + name = 'depth' + + def __init__(self, gpu_number=0, model_type='DPT_Large'): + super().__init__(gpu_number) + # Model options: MiDaS_small, DPT_Hybrid, DPT_Large + depth_estimation_model = hub.load('intel-isl/MiDaS', model_type, pretrained=True).to(self.dev) + depth_estimation_model.eval() + + midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") + + if model_type == "DPT_Large" or model_type == "DPT_Hybrid": + self.transform = midas_transforms.dpt_transform + else: + self.transform = midas_transforms.small_transform + + self.depth_estimation_model = depth_estimation_model + + @torch.no_grad() + def forward(self, image: torch.Tensor): + """Estimate depth map""" + image_numpy = image.cpu().permute(1, 2, 0).numpy() * 255 + input_batch = self.transform(image_numpy).to(self.dev) + prediction = self.depth_estimation_model(input_batch) + # Resize to original size + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=image_numpy.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + # We compute the inverse because the model returns inverse depth + to_return = 1 / prediction + to_return = to_return.cpu() + return to_return # To save: plt.imsave(path_save, prediction.cpu().numpy()) + + +class CLIPModel(BaseModel): + name = 'clip' + + def __init__(self, gpu_number=0, version="ViT-L/14@336px"): # @336px + super().__init__(gpu_number) + + import clip + self.clip = clip + + model, preprocess = clip.load(version, device=self.dev) + model.eval() + model.requires_grad_ = False + self.model = model + self.negative_text_features = None + self.transform = self.get_clip_transforms_from_tensor(336 if "336" in version else 224) + + # @staticmethod + def _convert_image_to_rgb(self, image): + return image.convert("RGB") + + # @staticmethod + def get_clip_transforms_from_tensor(self, n_px=336): + return transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(n_px), + self._convert_image_to_rgb, + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + @torch.no_grad() + def binary_score(self, image: torch.Tensor, prompt, negative_categories=None): + is_video = isinstance(image, torch.Tensor) and image.ndim == 4 + if is_video: # video + image = torch.stack([self.transform(image[i]) for i in range(image.shape[0])], dim=0) + else: + image = self.transform(image).unsqueeze(0).to(self.dev) + + prompt_prefix = "photo of " + prompt = prompt_prefix + prompt + + if negative_categories is None: + if self.negative_text_features is None: + self.negative_text_features = self.clip_negatives(prompt_prefix) + negative_text_features = self.negative_text_features + else: + negative_text_features = self.clip_negatives(prompt_prefix, negative_categories) + + text = self.clip.tokenize([prompt]).to(self.dev) + + image_features = self.model.encode_image(image.to(self.dev)) + image_features = F.normalize(image_features, dim=-1) + + pos_text_features = self.model.encode_text(text) + pos_text_features = F.normalize(pos_text_features, dim=-1) + + text_features = torch.concat([pos_text_features, negative_text_features], axis=0) + + # run competition where we do a binary classification + # between the positive and all the negatives, then take the mean + sim = (100.0 * image_features @ text_features.T).squeeze(dim=0) + if is_video: + query = sim[..., 0].unsqueeze(-1).broadcast_to(sim.shape[0], sim.shape[-1] - 1) + others = sim[..., 1:] + res = F.softmax(torch.stack([query, others], dim=-1), dim=-1)[..., 0].mean(-1) + else: + res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1), + sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean() + return res + + @torch.no_grad() + def clip_negatives(self, prompt_prefix, negative_categories=None): + if negative_categories is None: + with open('useful_lists/random_negatives.txt') as f: + negative_categories = [x.strip() for x in f.read().split()] + # negative_categories = negative_categories[:1000] + # negative_categories = ["a cat", "a lamp"] + negative_categories = [prompt_prefix + x for x in negative_categories] + negative_tokens = self.clip.tokenize(negative_categories).to(self.dev) + + negative_text_features = self.model.encode_text(negative_tokens) + negative_text_features = F.normalize(negative_text_features, dim=-1) + + return negative_text_features + + @torch.no_grad() + def classify(self, image: Union[torch.Tensor, list], categories: list[str], return_index=True): + is_list = isinstance(image, list) + if is_list: + assert len(image) == len(categories) + image = [self.transform(x).unsqueeze(0) for x in image] + image_clip = torch.cat(image, dim=0).to(self.dev) + elif len(image.shape) == 3: + image_clip = self.transform(image).to(self.dev).unsqueeze(0) + else: # Video (process images separately) + image_clip = torch.stack([self.transform(x) for x in image], dim=0).to(self.dev) + + # if len(image_clip.shape) == 3: + # image_clip = image_clip.unsqueeze(0) + + prompt_prefix = "photo of " + categories = [prompt_prefix + x for x in categories] + categories = self.clip.tokenize(categories).to(self.dev) + + text_features = self.model.encode_text(categories) + text_features = F.normalize(text_features, dim=-1) + + image_features = self.model.encode_image(image_clip) + image_features = F.normalize(image_features, dim=-1) + + if image_clip.shape[0] == 1: + # get category from image + softmax_arg = image_features @ text_features.T # 1 x n + else: + if is_list: + # get highest category-image match with n images and n corresponding categories + softmax_arg = (image_features @ text_features.T).diag().unsqueeze(0) # n x n -> 1 x n + else: + softmax_arg = (image_features @ text_features.T) + + similarity = (100.0 * softmax_arg).softmax(dim=-1).squeeze(0) + if not return_index: + return similarity + else: + result = torch.argmax(similarity, dim=-1) + if result.shape == (): + result = result.item() + return result + + @torch.no_grad() + def compare(self, images: list[torch.Tensor], prompt, return_scores=False): + images = [self.transform(im).unsqueeze(0).to(self.dev) for im in images] + images = torch.cat(images, dim=0) + + prompt_prefix = "photo of " + prompt = prompt_prefix + prompt + + text = self.clip.tokenize([prompt]).to(self.dev) + + image_features = self.model.encode_image(images.to(self.dev)) + image_features = F.normalize(image_features, dim=-1) + + text_features = self.model.encode_text(text) + text_features = F.normalize(text_features, dim=-1) + + sim = (image_features @ text_features.T).squeeze(dim=-1) # Only one text, so squeeze + + if return_scores: + return sim + res = sim.argmax() + return res + + def forward(self, image, prompt, task='score', return_index=True, negative_categories=None, return_scores=False): + if task == 'classify': + categories = prompt + clip_sim = self.classify(image, categories, return_index=return_index) + out = clip_sim + elif task == 'score': + clip_score = self.binary_score(image, prompt, negative_categories=negative_categories) + out = clip_score + else: # task == 'compare' + idx = self.compare(image, prompt, return_scores) + out = idx + if not isinstance(out, int): + out = out.cpu() + return out + + +class MaskRCNNModel(BaseModel): + name = 'maskrcnn' + + def __init__(self, gpu_number=0, threshold=0.8): + super().__init__(gpu_number) + obj_detect = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='COCO_V1').to(self.dev) + obj_detect.eval() + obj_detect.requires_grad_(False) + self.categories = torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1.meta['categories'] + self.obj_detect = obj_detect + self.threshold = threshold + + def prepare_image(self, image): + image = image.to(self.dev) + return image + + @torch.no_grad() + def detect(self, images: torch.Tensor, confidence_threshold: float = None): + if type(images) != list: + images = [images] + threshold = confidence_threshold if confidence_threshold is not None else self.threshold + + images = [self.prepare_image(im) for im in images] + detections = self.obj_detect(images) + scores = [] + for i in range(len(images)): + scores.append(detections[i]['scores'][detections[i]['scores'] > threshold]) + + height = detections[i]['masks'].shape[-2] + # Just return boxes (no labels no masks, no scores) with scores > threshold + d_i = detections[i]['boxes'][detections[i]['scores'] > threshold] + # Return [left, lower, right, upper] instead of [left, upper, right, lower] + detections[i] = torch.stack([d_i[:, 0], height - d_i[:, 3], d_i[:, 2], height - d_i[:, 1]], dim=1) + + return detections, scores + + def forward(self, image, confidence_threshold: float = None): + obj_detections, obj_scores = self.detect(image, confidence_threshold=confidence_threshold) + # Move to CPU before sharing. Alternatively we can try cloning tensors in CUDA, but may not work + obj_detections = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_detections] + obj_scores = [(v.to('cpu') if isinstance(v, torch.Tensor) else list(v)) for v in obj_scores] + return obj_detections, obj_scores + + +class GLIPModel(BaseModel): + name = 'glip' + + def __init__(self, model_size='large', gpu_number=0, *args): + BaseModel.__init__(self, gpu_number) + + # with contextlib.redirect_stderr(open(os.devnull, "w")): # Do not print nltk_data messages when importing + from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo, to_image_list, create_positive_map, \ + create_positive_map_label_to_token_from_positive_map + + working_dir = 'pretrained_models/GLIP/' + if model_size == 'tiny': + config_file = working_dir + "configs/glip_Swin_T_O365_GoldG.yaml" + weight_file = working_dir + "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth" + else: # large + config_file = working_dir + "configs/glip_Swin_L.yaml" + weight_file = working_dir + "checkpoints/glip_large_model.pth" + + class OurGLIPDemo(GLIPDemo): + + def __init__(self, dev, *args_demo): + + kwargs = { + 'min_image_size': 800, + 'confidence_threshold': 0.5, + 'show_mask_heatmaps': False + } + + self.dev = dev + + from maskrcnn_benchmark.config import cfg + + # manual override some options + cfg.local_rank = 0 + cfg.num_gpus = 1 + cfg.merge_from_file(config_file) + cfg.merge_from_list(["MODEL.WEIGHT", weight_file]) + cfg.merge_from_list(["MODEL.DEVICE", self.dev]) + + from transformers.utils import logging + + logging.set_verbosity_error() + GLIPDemo.__init__(self, cfg, *args_demo, **kwargs) + if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": + plus = 1 + else: + plus = 0 + self.plus = plus + self.color = 255 + + @torch.no_grad() + def compute_prediction(self, original_image, original_caption, custom_entity=None): + image = self.transforms(original_image) + # image = [image, image.permute(0, 2, 1)] + image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY) + image_list = image_list.to(self.dev) + # caption + if isinstance(original_caption, list): + + if len(original_caption) > 40: + all_predictions = None + for loop_num, i in enumerate(range(0, len(original_caption), 40)): + list_step = original_caption[i:i + 40] + prediction_step = self.compute_prediction(original_image, list_step, custom_entity=None) + if all_predictions is None: + all_predictions = prediction_step + else: + # Aggregate predictions + all_predictions.bbox = torch.cat((all_predictions.bbox, prediction_step.bbox), dim=0) + for k in all_predictions.extra_fields: + all_predictions.extra_fields[k] = \ + torch.cat((all_predictions.extra_fields[k], + prediction_step.extra_fields[k] + loop_num), dim=0) + return all_predictions + + # we directly provided a list of category names + caption_string = "" + tokens_positive = [] + seperation_tokens = " . " + for word in original_caption: + tokens_positive.append([len(caption_string), len(caption_string) + len(word)]) + caption_string += word + caption_string += seperation_tokens + + tokenized = self.tokenizer([caption_string], return_tensors="pt") + # tokens_positive = [tokens_positive] # This was wrong + tokens_positive = [[v] for v in tokens_positive] + + original_caption = caption_string + # print(tokens_positive) + else: + tokenized = self.tokenizer([original_caption], return_tensors="pt") + if custom_entity is None: + tokens_positive = self.run_ner(original_caption) + # print(tokens_positive) + # process positive map + positive_map = create_positive_map(tokenized, tokens_positive) + + positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, + plus=self.plus) + self.positive_map_label_to_token = positive_map_label_to_token + tic = timeit.time.perf_counter() + + # compute predictions + predictions = self.model(image_list, captions=[original_caption], + positive_map=positive_map_label_to_token) + predictions = [o.to(self.cpu_device) for o in predictions] + # print("inference time per image: {}".format(timeit.time.perf_counter() - tic)) + + # always single image is passed at a time + prediction = predictions[0] + + # reshape prediction (a BoxList) into the original image size + height, width = original_image.shape[-2:] + # if self.tensor_inputs: + # else: + # height, width = original_image.shape[:-1] + prediction = prediction.resize((width, height)) + + if prediction.has_field("mask"): + # if we have masks, paste the masks in the right position + # in the image, as defined by the bounding boxes + masks = prediction.get_field("mask") + # always single image is passed at a time + masks = self.masker([masks], [prediction])[0] + prediction.add_field("mask", masks) + + return prediction + + @staticmethod + def to_left_right_upper_lower(bboxes): + return [(bbox[1], bbox[3], bbox[0], bbox[2]) for bbox in bboxes] + + @staticmethod + def to_xmin_ymin_xmax_ymax(bboxes): + # invert the previous method + return [(bbox[2], bbox[0], bbox[3], bbox[1]) for bbox in bboxes] + + @staticmethod + def prepare_image(image): + image = image[[2, 1, 0]] # convert to bgr for opencv-format for glip + return image + + @torch.no_grad() + def forward(self, image: torch.Tensor, obj: Union[str, list], confidence_threshold=None): + if confidence_threshold is not None: + original_confidence_threshold = self.confidence_threshold + self.confidence_threshold = confidence_threshold + + # if isinstance(object, list): + # object = ' . '.join(object) + ' .' # add separation tokens + image = self.prepare_image(image) + + # Avoid the resizing creating a huge image in a pathological case + ratio = image.shape[1] / image.shape[2] + ratio = max(ratio, 1 / ratio) + original_min_image_size = self.min_image_size + if ratio > 10: + self.min_image_size = int(original_min_image_size * 10 / ratio) + self.transforms = self.build_transform() + + with torch.cuda.device(self.dev): + inference_output = self.inference(image, obj) + + bboxes = inference_output.bbox.cpu().numpy().astype(int) + # bboxes = self.to_left_right_upper_lower(bboxes) + + if ratio > 10: + self.min_image_size = original_min_image_size + self.transforms = self.build_transform() + + bboxes = torch.tensor(bboxes) + + # Convert to [left, lower, right, upper] instead of [left, upper, right, lower] + height = image.shape[-2] + bboxes = torch.stack([bboxes[:, 0], height - bboxes[:, 3], bboxes[:, 2], height - bboxes[:, 1]], dim=1) + + if confidence_threshold is not None: + self.confidence_threshold = original_confidence_threshold + + # subtract 1 because it's 1-indexed for some reason + # return bboxes, inference_output.get_field("labels").cpu().numpy() - 1 + return bboxes, inference_output.get_field("scores") + + self.glip_demo = OurGLIPDemo(*args, dev=self.dev) + + def forward(self, *args, **kwargs): + return self.glip_demo.forward(*args, **kwargs) + + +class BLIPModel(BaseModel): + name = 'blip' + to_batch = True + max_batch_size = 32 + seconds_collect_data = 0.2 # The queue has additionally the time it is executing the previous forward pass + + def __init__(self, gpu_number=0, half_precision=True, blip_v2_model_type="blip2-flan-t5-xl"): + super().__init__(gpu_number) + + # from lavis.models import load_model_and_preprocess + from transformers import Blip2Processor, Blip2ForConditionalGeneration + + # https://huggingface.co/models?sort=downloads&search=Salesforce%2Fblip2- + assert blip_v2_model_type in ['blip2-flan-t5-xxl', 'blip2-flan-t5-xl', 'blip2-opt-2.7b', 'blip2-opt-6.7b', + 'blip2-opt-2.7b-coco', 'blip2-flan-t5-xl-coco', 'blip2-opt-6.7b-coco'] + + with torch.cuda.device(self.dev): + max_memory = {gpu_number: torch.cuda.mem_get_info(self.dev)[0]} + + self.processor = Blip2Processor.from_pretrained(f"Salesforce/{blip_v2_model_type}") + # Device_map must be sequential for manual GPU selection + try: + self.model = Blip2ForConditionalGeneration.from_pretrained( + f"Salesforce/{blip_v2_model_type}", load_in_8bit=half_precision, + torch_dtype=torch.float16 if half_precision else "auto", + device_map="sequential", max_memory=max_memory + ) + except Exception as e: + # Clarify error message. The problem is that it tries to load part of the model to disk. + if "had weights offloaded to the disk" in e.args[0]: + extra_text = ' You may want to consider setting half_precision to True.' if half_precision else '' + raise MemoryError(f"Not enough GPU memory in GPU {self.dev} to load the model.{extra_text}") + else: + raise e + + self.qa_prompt = "Question: {} Short answer:" + self.caption_prompt = "a photo of" + self.half_precision = half_precision + self.max_words = 50 + + @torch.no_grad() + def caption(self, image, prompt=None): + inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.dev, torch.float16) + generation_output = self.model.generate(**inputs, length_penalty=1., num_beams=5, max_length=30, min_length=1, + do_sample=False, top_p=0.9, repetition_penalty=1.0, + num_return_sequences=1, temperature=1, + return_dict_in_generate=True, output_scores=True) + generated_text = [cap.strip() for cap in self.processor.batch_decode( + generation_output.sequences, skip_special_tokens=True)] + return generated_text, generation_output.sequences_scores.cpu().numpy().tolist() + + def pre_question(self, question): + # from LAVIS blip_processors + question = re.sub( + r"([.!\"()*#:;~])", + "", + question.lower(), + ) + question = question.rstrip(" ") + + # truncate question + question_words = question.split(" ") + if len(question_words) > self.max_words: + question = " ".join(question_words[: self.max_words]) + + return question + + @torch.no_grad() + def qa(self, image, question): + inputs = self.processor(images=image, text=question, return_tensors="pt", padding="longest").to(self.dev) + if self.half_precision: + inputs['pixel_values'] = inputs['pixel_values'].half() + generation_output = self.model.generate(**inputs, length_penalty=-1, num_beams=5, max_length=10, min_length=1, + do_sample=False, top_p=0.9, repetition_penalty=1.0, + num_return_sequences=1, temperature=1, + return_dict_in_generate=True, output_scores=True) + generated_text = self.processor.batch_decode(generation_output.sequences, skip_special_tokens=True) + return generated_text, generation_output.sequences_scores.cpu().numpy().tolist() + + def forward(self, image, question=None, task='caption'): + if not self.to_batch: + image, question, task = [image], [question], [task] + + if len(image) > 0 and 'float' in str(image[0].dtype) and image[0].max() <= 1: + image = [im * 255 for im in image] + + # Separate into qa and caption batches. + prompts_qa = [self.qa_prompt.format(self.pre_question(q)) for q, t in zip(question, task) if t == 'qa'] + images_qa = [im for i, im in enumerate(image) if task[i] == 'qa'] + images_caption = [im for i, im in enumerate(image) if task[i] == 'caption'] + + with torch.cuda.device(self.dev): + response_qa, scores_qa = self.qa(images_qa, prompts_qa) if len(images_qa) > 0 else ([], []) + response_caption, scores_caption = self.caption(images_caption) if len(images_caption) > 0 else ([], []) + + response = [] + for t in task: + if t == 'qa': + response.append([response_qa.pop(0), scores_qa.pop(0)]) + else: + response.append([response_caption.pop(0), scores_caption.pop(0)]) + + if not self.to_batch: + response = response[0] + return response + + +class XVLMModel(BaseModel): + name = 'xvlm' + + def __init__(self, gpu_number=0, path_checkpoint='pretrained_models/xvlm/retrieval_mscoco_checkpoint_9.pth'): + + from xvlm.xvlm import XVLMBase + from transformers import BertTokenizer + + super().__init__(gpu_number) + + image_res = 384 + self.max_words = 30 + config_xvlm = { + 'image_res': image_res, + 'patch_size': 32, + 'text_encoder': 'bert-base-uncased', + 'block_num': 9, + 'max_tokens': 40, + 'embed_dim': 256, + } + + vision_config = { + 'vision_width': 1024, + 'image_res': 384, + 'window_size': 12, + 'embed_dim': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32] + } + model = XVLMBase(config_xvlm, use_contrastive_loss=True, vision_config=vision_config) + checkpoint = torch.load(path_checkpoint, map_location='cpu') + state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint + msg = model.load_state_dict(state_dict, strict=False) + if len(msg.missing_keys) > 0: + print('XVLM Missing keys: ', msg.missing_keys) + + model = model.to(self.dev) + model.eval() + + self.model = model + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + self.transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC), + transforms.ToTensor(), + normalize, + ]) + + with open('useful_lists/random_negatives.txt') as f: + self.negative_categories = [x.strip() for x in f.read().split()] + + @staticmethod + def pre_caption(caption, max_words): + caption = re.sub( + r"([,.'!?\"()*#:;~])", + '', + caption.lower(), + ).replace('-', ' ').replace('/', ' ').replace('', 'person') + + caption = re.sub( + r"\s{2,}", + ' ', + caption, + ) + caption = caption.rstrip('\n') + caption = caption.strip(' ') + + # truncate caption + caption_words = caption.split(' ') + if len(caption_words) > max_words: + caption = ' '.join(caption_words[:max_words]) + + if not len(caption): + raise ValueError("pre_caption yields invalid text") + + return caption + + @torch.no_grad() + def score(self, images, texts): + + if isinstance(texts, str): + texts = [texts] + + if not isinstance(images, list): + images = [images] + + images = [self.transform(image) for image in images] + images = torch.stack(images, dim=0).to(self.dev) + + texts = [self.pre_caption(text, self.max_words) for text in texts] + text_input = self.tokenizer(texts, padding='longest', return_tensors="pt").to(self.dev) + + image_embeds, image_atts = self.model.get_vision_embeds(images) + text_ids, text_atts = text_input.input_ids, text_input.attention_mask + text_embeds = self.model.get_text_embeds(text_ids, text_atts) + + image_feat, text_feat = self.model.get_features(image_embeds, text_embeds) + logits = image_feat @ text_feat.t() + + return logits + + @torch.no_grad() + def binary_score(self, image, text, negative_categories): + # Compare with a pre-defined set of negatives + texts = [text] + negative_categories + sim = 100 * self.score(image, texts)[0] + res = F.softmax(torch.cat((sim[0].broadcast_to(1, sim.shape[0] - 1), + sim[1:].unsqueeze(0)), dim=0), dim=0)[0].mean() + return res + + def forward(self, image, text, task='score', negative_categories=None): + if task == 'score': + score = self.score(image, text) + else: # binary + score = self.binary_score(image, text, negative_categories=negative_categories) + return score.cpu() diff --git a/vision_processes.py b/vision_processes.py new file mode 100644 index 0000000000000000000000000000000000000000..d955feaa5a77c2ecfde21531c856455b9d2a973c --- /dev/null +++ b/vision_processes.py @@ -0,0 +1,76 @@ +import inspect +import traceback + +import torch + +import vision_models + +consumers = dict() + + +def load_models(): + global consumers + list_models = [m[1] for m in inspect.getmembers(vision_models, inspect.isclass) + if issubclass(m[1], vision_models.BaseModel) and m[1] != vision_models.BaseModel] + list_models.sort(key=lambda x: x.load_order) + print("-" * 10, "List models", list_models) + + counter_ = 0 + for model_class_ in list_models: + print("-" * 10, "Now loading {}:".format(model_class_)) + for process_name_ in model_class_.list_processes(): + consumers[process_name_] = make_fn(model_class_, process_name_, counter_) + counter_ += 1 + print("-" * 10, "Loading {} finished. Current gpu:".format(model_class_)) + print(torch.cuda.memory_summary()) + + print("-" * 10, "Model loading finished. Final gpu:") + print(torch.cuda.memory_summary()) + + +def make_fn(model_class, process_name, counter): + """ + model_class.name and process_name will be the same unless the same model is used in multiple processes, for + different tasks + """ + # We initialize each one on a separate GPU, to make sure there are no out of memory errors + num_gpus = torch.cuda.device_count() + gpu_number = counter % num_gpus + + model_instance = model_class(gpu_number=gpu_number) + + def _function(*args, **kwargs): + if process_name != model_class.name: + kwargs['process_name'] = process_name + + if model_class.to_batch: + # Batchify the input. Model expects a batch. And later un-batchify the output. + args = [[arg] for arg in args] + kwargs = {k: [v] for k, v in kwargs.items()} + + # The defaults that are not in args or kwargs, also need to listify + full_arg_spec = inspect.getfullargspec(model_instance.forward) + if full_arg_spec.defaults is None: + default_dict = {} + else: + default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults)) + non_given_args = full_arg_spec.args[1:][len(args):] + non_given_args = set(non_given_args) - set(kwargs.keys()) + for arg_name in non_given_args: + kwargs[arg_name] = [default_dict[arg_name]] + + try: + out = model_instance.forward(*args, **kwargs) + if model_class.to_batch: + out = out[0] + except Exception as e: + print(f'Error in {process_name} model:', e) + traceback.print_exc() + out = None + return out + + return _function + + +def forward(model_name, *args, **kwargs): + return consumers[model_name](*args, **kwargs) diff --git a/xvlm/config_bert.json b/xvlm/config_bert.json new file mode 100644 index 0000000000000000000000000000000000000000..d8f56450ffd6fda4875af050e396fb2eba27ec1c --- /dev/null +++ b/xvlm/config_bert.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30522, + "fusion_layer": 6, + "encoder_width": 1024 +} diff --git a/xvlm/swin_transformer.py b/xvlm/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1affc9a8695474e831ad060343c1988d750dc5f --- /dev/null +++ b/xvlm/swin_transformer.py @@ -0,0 +1,654 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import numpy as np +from scipy import interpolate + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + + x_cls = self.avgpool(x.transpose(1, 2)) # B C 1 + + if idx_to_group_img is None: + return torch.cat([x_cls.transpose(1, 2), x], dim=1) + else: + x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) + weights = image_atts[:, 1:].unsqueeze(2) # B L 1 + x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1 + x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool + + return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \ + torch.cat([x_cls.transpose(1, 2), x], dim=1) + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''): + # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348 + + # rel_pos_bias: relative_position_bias_table + src_num_pos, num_attn_heads = rel_pos_bias.size() + + num_extra_tokens = 0 + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size)) + + # extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + # print("Original positions = %s" % str(x)) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + return rel_pos_bias \ No newline at end of file diff --git a/xvlm/vit.py b/xvlm/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..adddf9ef64a3100a879fd628d5b999b5b6f2b1af --- /dev/null +++ b/xvlm/vit.py @@ -0,0 +1,246 @@ +from functools import partial + +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_, DropPath +from timm.models.vision_transformer import PatchEmbed + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False, image_atts=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if image_atts is not None: + attn += image_atts + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + # attn: (bs, num_heads, num_patches, num_patches) + # v: (bs, num_heads, num_patches, d) + # attn @ v: (bs, num_heads, num_patches, d) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, register_hook=False, image_atts=None): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook, image_atts=image_atts)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, local_attn_depth=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + self.num_patch_embed = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self.num_pos_embed = self.num_patch_embed + 1 + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_pos_embed, embed_dim)) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.depth = depth + self.local_attn_depth = local_attn_depth # do local attn from index=(depth - local_attn_depth) + + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1, idx_to_group_img=None, image_atts=None): + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:, :x.size(1), :] + x = self.pos_drop(x) + + do_gather = True if idx_to_group_img is not None else False + + if do_gather and (image_atts is not None): + full_atts = torch.ones(x.shape[:2], dtype=x.dtype).to(x.device) + image_atts_blk = torch.cat([image_atts, full_atts], dim=0) + + image_atts_blk = image_atts_blk.unsqueeze(1).unsqueeze(2) + image_atts_blk = (1.0 - image_atts_blk) * -10000.0 + else: + image_atts_blk = None + + for i, blk in enumerate(self.blocks): + if (self.local_attn_depth > 0) and (i >= self.depth - self.local_attn_depth): + if do_gather: + do_gather = False + + x_bs = torch.gather(x, dim=0, + index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) + x = torch.cat([x_bs, x], dim=0) + + x = blk(x, register_blk == i, image_atts=image_atts_blk) + + else: + x = blk(x, register_blk == i, image_atts=None) + + x = self.norm(x) + + if idx_to_group_img is not None: + bs = len(idx_to_group_img) + x_bs, x_fullatts = torch.split(x, [bs, x.size(0) - bs]) + return x_bs, x_fullatts + + return x + + +def interpolate_pos_embed(pos_embed_checkpoint, num_patches, num_extra_tokens=1): + # num_patches = visual_encoder.num_patch_embed + # num_extra_tokens = visual_encoder.num_pos_embed - visual_encoder.num_patch_embed + + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size != new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + # print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint diff --git a/xvlm/xbert.py b/xvlm/xbert.py new file mode 100644 index 0000000000000000000000000000000000000000..fb36b166ac5b37d96e71fb541bc180d409d7fb82 --- /dev/null +++ b/xvlm/xbert.py @@ -0,0 +1,2058 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from torch import Tensor, device +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +transformers.logging.set_verbosity_error() + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.fp16 = getattr(config, 'fp16', False) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + if not self.fp16: + query_layer = self.transpose_for_scores(mixed_query_layer) + else: + # to avoid gradient overflow + query_layer = self.transpose_for_scores(mixed_query_layer) / math.sqrt( + self.attention_head_size) # bsz, max_length, hidden_size + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if not self.fp16: + attention_scores = attention_scores / math.sqrt(self.attention_head_size) # bsz, 12, max_length, max_length + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = (layer_num >= config.fusion_layer) + if self.has_cross_attention: + self.layer_num = layer_num + self.crossattention = BertAttention(config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[(self.layer_num - self.config.fusion_layer) % len(encoder_hidden_states)], + encoder_attention_mask[(self.layer_num - self.config.fusion_layer) % len(encoder_hidden_states)], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multi_modal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + if mode == 'text': + start_layer = 0 + output_layer = self.config.fusion_layer + + elif mode == 'fusion': + start_layer = self.config.fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == 'multi_modal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + else: + raise ValueError(f"mode {mode} is not supported") + + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.BertForPreTraining`. + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class LabelSmoothSoftmaxCEV1(nn.Module): + ''' + This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients + ''' + + def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100): + super(LabelSmoothSoftmaxCEV1, self).__init__() + self.lb_smooth = lb_smooth + self.reduction = reduction + self.lb_ignore = ignore_index + self.log_softmax = nn.LogSoftmax(dim=1) + + def forward(self, logits, label): + ''' + Same usage method as nn.CrossEntropyLoss: + # >>> criteria = LabelSmoothSoftmaxCEV1() + # >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half + # >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t + # >>> loss = criteria(logits, lbs) + ''' + # overcome ignored label + logits = logits.float() # use fp32 to avoid nan + with torch.no_grad(): + num_classes = logits.size(1) + label = label.clone().detach() + ignore = label.eq(self.lb_ignore) + n_valid = ignore.eq(0).sum() + label[ignore] = 0 + lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes + lb_one_hot = torch.empty_like(logits).fill_( + lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach() + + logs = self.log_softmax(logits) + loss = -torch.sum(logs * lb_one_hot, dim=1) + loss[ignore] = 0 + if self.reduction == 'mean': + loss = loss.sum() / n_valid + if self.reduction == 'sum': + loss = loss.sum() + + return loss + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config, label_smoothing=0.0): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + + self.cls = BertOnlyMLMHead(config) + + self.label_smoothing = label_smoothing + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + mode='multi_modal', + return_logits=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + if self.label_smoothing > 0: + loss_fct = LabelSmoothSoftmaxCEV1(lb_smooth=self.label_smoothing, reduction=reduction) + else: + loss_fct = CrossEntropyLoss(reduction=reduction) + + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + batch_size, + **model_kwargs + ): + """ Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = [] + cur_unfinished = input_ids.new(batch_size).fill_(1) + + # log of scores for each sentence in the batch + logprobs = [] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + for i in range(batch_size): + for previous_token in set(input_ids[i].tolist()): + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if next_token_logits[i, previous_token] < 0: + next_token_logits[i, previous_token] *= repetition_penalty + else: + next_token_logits[i, previous_token] /= repetition_penalty + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # Compute scores + _scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size, vocab_size) + _scores = torch.gather(_scores, -1, next_token.unsqueeze(-1)) # (batch_size, 1) + logprobs.append(_scores) # (batch_size, 1) + unfinished_sents.append(cur_unfinished) + + # update generations and finished sentences + tokens_to_add = next_token * cur_unfinished + pad_token_id * (1 - cur_unfinished) + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + cur_len = cur_len + 1 + + for eos_token_id in eos_token_ids: + cur_unfinished = cur_unfinished.mul(tokens_to_add.ne(eos_token_id).long()) + + # stop when there is a in each sentence, or if we exceed the maximul length + if cur_unfinished.max() == 0: + break + + # add eos_token_ids to unfinished sentences + if cur_len == max_length: + input_ids[:, -1].masked_fill_(cur_unfinished.to(dtype=torch.bool), eos_token_ids[0]) + + logprobs = torch.cat(logprobs, dim=1) + unfinished_sents = torch.stack(unfinished_sents, dim=1).float() + sum_logprobs = (logprobs * unfinished_sents).sum(dim=1) + # return logprobs to keep consistent with beam search output + logprobs = sum_logprobs / unfinished_sents.sum(dim=1) + + # pad to the same length, otherwise DataParallel will give error + pad_len = max_length - input_ids.shape[1] + if pad_len > 0: + padding_ids = input_ids.new(batch_size, pad_len).fill_(pad_token_id) + input_ids = torch.cat([input_ids, padding_ids], dim=1) + + return input_ids, logprobs + + +def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def gather_seq_out_by_pos(self, seq, pos): + return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + return_logits=False, + masked_pos=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + + if masked_pos is not None: + # sequence_output, (bs, len, 768) + # masked_pos, (bs, n_mask) + sequence_output = self.gather_seq_out_by_pos(sequence_output, masked_pos) + # sequence_output, (bs, n_mask, 768) + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/xvlm/xvlm.py b/xvlm/xvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..e11d7bd57ebf22586efea0659ad19bd0af458e3d --- /dev/null +++ b/xvlm/xvlm.py @@ -0,0 +1,401 @@ +# Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) +# Github: https://github.com/zengyan-97/X-VLM +# Copyright (c) 2022, ByteDance Inc. +# All rights reserved. + +import json +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from xvlm.swin_transformer import SwinTransformer, interpolate_relative_pos_embed +from xvlm.vit import interpolate_pos_embed +from xvlm.xbert import BertConfig, BertForMaskedLM, BertModel + + +def read_json(rpath): + with open(rpath, 'r') as f: + return json.load(f) + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + + @staticmethod + def forward(ctx, tensor, rank, world_size): + output = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(output, tensor) + ctx.rank = rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, 0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], + None, + None + ) + + +allgather = AllGather.apply + + +def build_vision_encoder(vision_config, load_params=False): + """ + Args: + load_params: False when building fine-tuning models + """ + vision_width = vision_config['vision_width'] + + vision_encoder = SwinTransformer(img_size=vision_config['image_res'], + patch_size=4, + in_chans=3, + embed_dim=vision_config['embed_dim'], + depths=vision_config['depths'], + num_heads=vision_config['num_heads'], + window_size=vision_config['window_size'], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + if load_params: + # download from https://github.com/microsoft/Swin-Transformer + state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] + + for k in list(state_dict.keys()): + if 'relative_position_bias_table' in k: + dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 + state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) + elif ('relative_position_index' in k) or ('attn_mask' in k): + del state_dict[k] + + if load_params: + print("### Load ViT: ", flush=True) + msg = vision_encoder.load_state_dict(state_dict, strict=False) + print("missing_keys: ", msg.missing_keys) + print("unexpected_keys: ", msg.unexpected_keys) + + return vision_encoder, vision_width + + +def build_text_encoder(config, vision_width, load_text_params=False, use_mlm_loss=False, config_text=None): + init_params = [] # train from scratch with larger lr + + config_text = BertConfig.from_json_file('xvlm/config_bert.json') + config_text.encoder_width = vision_width + + if use_mlm_loss: # for pre-training, load_text_params by default (otherwise notimplemented) + assert load_text_params is True + if ('accelerator' in config.keys()) and (config['accelerator']['FP16_OPT_LEVEL'] != 'O0'): + config_text.fp16 = True # will use some operations to avoid gradient overflow + + text_encoder, msg = BertForMaskedLM.from_pretrained(config['text_encoder'], config=config_text, + output_loading_info=True) + + print("### Load BERT: ") + for k, v in msg.items(): + print(f"{k}: {sorted(v)}") + + init_params.extend(['text_encoder.' + n for n in msg['missing_keys']]) # of cross attention + + if ('load_bertL_by_sep' in config.keys()) and config['load_bertL_by_sep']: + state_dict = torch.load(os.path.join(config['text_encoder'], 'pytorch_model.bin')) + for idx, i_layer in enumerate([13, 15, 17, 19, 21, 23]): + state_dict_i = {k[22:]: v for k, v in state_dict.items() if f'layer.{i_layer}' in k} + msg = text_encoder.bert.encoder.layer[config_text.fusion_layer + idx]. \ + load_state_dict(state_dict_i, strict=False) + print(f"### Load {i_layer} to {config_text.fusion_layer + idx}-layer: {msg}") + + else: # for fine-tuning, not load_text_params by default + assert load_text_params is False + + text_encoder = BertModel(config=config_text, add_pooling_layer=False) + + return text_encoder, init_params + + +def build_mlp(input_dim, output_dim): + return nn.Sequential( + nn.Linear(input_dim, input_dim * 2), + nn.LayerNorm(input_dim * 2), + nn.GELU(), + nn.Linear(input_dim * 2, output_dim) + ) + + +def load_pretrained(ckpt_rpath, config, is_eval=False, load_text=False): + checkpoint = torch.load(ckpt_rpath, map_location='cpu') + state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint + + if is_eval: + return state_dict + + num_patches = (config['image_res'] // config['patch_size']) ** 2 + + print("### Loading pretrained vision encoder", flush=True) + if config['use_clip_vit']: + del state_dict['vision_encoder.position_ids'] + pos_embed_reshaped = interpolate_pos_embed(state_dict['vision_encoder.pos_embed.weight'].unsqueeze(dim=0), + num_patches=num_patches, num_extra_tokens=1) + state_dict['vision_encoder.pos_embed.weight'] = pos_embed_reshaped.squeeze(dim=0) + + elif config['use_swin']: + + window_size = read_json(config['vision_config'])['window_size'] + + for k in list(state_dict.keys()): + if 'relative_position_bias_table' in k: + dst_num_pos = (2 * window_size - 1) ** 2 + state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) + elif ('relative_position_index' in k) or ('attn_mask' in k): + del state_dict[k] + + else: + pos_embed_reshaped = interpolate_pos_embed(state_dict['vision_encoder.pos_embed'], + num_patches=num_patches, num_extra_tokens=1) + state_dict['vision_encoder.pos_embed'] = pos_embed_reshaped + + if load_text: + print("### Loading pretrained text encoder", flush=True) + for key in list(state_dict.keys()): + if 'text_encoder.' in key: + if 'bert.' in key: + encoder_key = key.replace('bert.', '') + state_dict[encoder_key] = state_dict[key] + del state_dict[key] + + return state_dict + + +class XVLMBase(nn.Module): + def __init__(self, config=None, load_vision_params=False, load_text_params=False, + use_contrastive_loss=False, use_matching_loss=False, use_mlm_loss=False, use_bbox_loss=False, + config_text=None, vision_config=None): + super().__init__() + self.init_params = [] # train from scratch with larger lr + + self.vision_encoder, vision_width = build_vision_encoder(vision_config, load_params=load_vision_params) + + self.text_encoder, init_params = build_text_encoder(vision_config, vision_width=vision_width, + load_text_params=load_text_params, + use_mlm_loss=use_mlm_loss, + config_text=config_text) # text & cross-modal + self.init_params.extend(init_params) + + self.vision_width = vision_width + self.text_width = self.text_encoder.config.hidden_size # i.e. cross_width + + if use_contrastive_loss: + self.embed_dim = config['embed_dim'] + self.vision_proj = nn.Linear(self.vision_width, self.embed_dim) + self.text_proj = nn.Linear(self.text_width, self.embed_dim) + self.init_params.extend(['vision_proj.' + n for n, _ in self.vision_proj.named_parameters()]) + self.init_params.extend(['text_proj.' + n for n, _ in self.text_proj.named_parameters()]) + + if use_matching_loss: + self.itm_head = build_mlp(input_dim=self.text_width, output_dim=2) + self.init_params.extend(['itm_head.' + n for n, _ in self.itm_head.named_parameters()]) + + if use_bbox_loss: + self.bbox_head = build_mlp(input_dim=self.text_width, output_dim=4) + self.init_params.extend(['bbox_head.' + n for n, _ in self.bbox_head.named_parameters()]) + + def load_pretrained(self, ckpt_rpath, config, is_eval=False): + state_dict = load_pretrained(ckpt_rpath, config, is_eval=is_eval, load_text=True) + msg = self.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % ckpt_rpath) + print("missing_keys: ", [p for p in msg.missing_keys if 'vision_encoder' not in p]) + print("unexpected_keys: ", msg.unexpected_keys) + + def get_vision_embeds(self, image, image_atts=None, idx_to_group_img=None): + """ + vision_embeds: cls + patch embeds + """ + if idx_to_group_img is None: + image_embeds = self.vision_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + return image_embeds, image_atts + + else: + if image_atts is None: + image_embeds_fullatts = self.vision_encoder(image) + image_embeds_fullatts = torch.gather(image_embeds_fullatts, dim=0, + index=idx_to_group_img.view(-1, 1, 1).expand( + -1, image_embeds_fullatts.shape[1], + image_embeds_fullatts.shape[2])) + + image_atts = torch.ones(image_embeds_fullatts.size()[:-1], dtype=torch.long).to(image.device) + + return image_embeds_fullatts, image_atts + + else: + assert image_atts.size(0) == idx_to_group_img.size(0) # bsz + image_embeds, image_embeds_fullatts = \ + self.vision_encoder(image, idx_to_group_img=idx_to_group_img, image_atts=image_atts) + + image_embeds_fullatts = torch.gather(image_embeds_fullatts, dim=0, + index=idx_to_group_img.view(-1, 1, 1).expand( + -1, image_embeds_fullatts.shape[1], + image_embeds_fullatts.shape[2])) + + return image_embeds, image_atts, image_embeds_fullatts + + def get_text_embeds(self, text_ids, text_atts): + encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder + return encoder(text_ids, attention_mask=text_atts, return_dict=True, mode='text').last_hidden_state + + def get_cross_embeds(self, image_embeds, image_atts, text_ids=None, text_embeds=None, text_atts=None): + assert text_atts is not None + + encoder = self.text_encoder.bert if hasattr(self.text_encoder, 'bert') else self.text_encoder + + if text_embeds is not None: + return encoder(encoder_embeds=text_embeds, + attention_mask=text_atts, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + mode='fusion', + ).last_hidden_state + elif text_ids is not None: + return encoder(text_ids, + attention_mask=text_atts, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ).last_hidden_state + else: + raise ValueError + + def get_features(self, image_embeds=None, text_embeds=None): + if image_embeds is None: + return F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + elif text_embeds is None: + return F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + else: + return F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1), \ + F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + def get_contrastive_loss(self, image_feat, text_feat, idx=None): + """ + Args: + image_feat, text_feat: normalized + + Returns: contrastive loss + + """ + assert image_feat.size(-1) == self.embed_dim + assert text_feat.size(-1) == self.embed_dim + + image_feat_all = allgather(image_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) + text_feat_all = allgather(text_feat, torch.distributed.get_rank(), torch.distributed.get_world_size()) + logits = image_feat_all @ text_feat_all.t() / self.temp + + bsz = image_feat_all.shape[0] + + if idx is None: + labels = torch.arange(bsz, device=image_feat.device) + loss_i2t = F.cross_entropy(logits, labels) + loss_t2i = F.cross_entropy(logits.t(), labels) + + else: + idx = idx.view(-1, 1) + assert idx.size(0) == image_feat.size(0) + idx_all = allgather(idx, torch.distributed.get_rank(), torch.distributed.get_world_size()) + pos_idx = torch.eq(idx_all, idx_all.t()).float() + labels = pos_idx / pos_idx.sum(1, keepdim=True) + + loss_i2t = -torch.sum(F.log_softmax(logits, dim=1) * labels, dim=1).mean() + loss_t2i = -torch.sum(F.log_softmax(logits.t(), dim=1) * labels, dim=1).mean() + + return (loss_i2t + loss_t2i) / 2 + + def get_matching_loss(self, image_embeds, image_atts, image_feat, text_embeds, text_atts, text_feat, idx=None): + """ + Matching Loss with hard negatives + """ + bs = image_embeds.size(0) + with torch.no_grad(): + sim_i2t = image_feat @ text_feat.t() / self.temp + sim_t2i = text_feat @ image_feat.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-5 + weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-5 + + if idx is None: + weights_i2t.fill_diagonal_(0) + weights_t2i.fill_diagonal_(0) + else: + idx = idx.view(-1, 1) + assert idx.size(0) == bs + mask = torch.eq(idx, idx.t()) + weights_i2t.masked_fill_(mask, 0) + weights_t2i.masked_fill_(mask, 0) + + image_embeds_neg = [] + image_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_atts_neg.append(image_atts[neg_idx]) + + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + image_atts_neg = torch.stack(image_atts_neg, dim=0) + + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text_atts[neg_idx]) + + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0) + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts_neg, image_atts], dim=0) + + cross_pos = self.get_cross_embeds(image_embeds, image_atts, text_embeds=text_embeds, text_atts=text_atts)[:, 0, + :] + cross_neg = self.get_cross_embeds(image_embeds_all, image_atts_all, text_embeds=text_embeds_all, + text_atts=text_atts_all)[:, 0, :] + + output = self.itm_head(torch.cat([cross_pos, cross_neg], dim=0)) + itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), + torch.zeros(2 * bs, dtype=torch.long)], dim=0).to(image_embeds.device) + + return F.cross_entropy(output, itm_labels) + + def get_mlm_loss(self, text_ids_masked, text_atts, image_embeds, image_atts, masked_pos, masked_ids): + return self.text_encoder(text_ids_masked, + attention_mask=text_atts, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + labels=masked_ids, + masked_pos=masked_pos).loss + + def predict_bbox(self, image_embeds, text_embeds, text_atts): + """ + Args: + image_embeds: encoding full images + + Returns: + output_coord: bsz, 4 + """ + assert image_embeds.size(0) == text_embeds.size(0) + + output_cls = self.get_cross_embeds(image_embeds, torch.ones(image_embeds.shape[:2]).to(image_embeds.device), + text_embeds=text_embeds, text_atts=text_atts)[:, 0, :] + output_coord = self.bbox_head(output_cls).sigmoid() + + return output_coord