diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 1a9cf115bb98974866dc6941312214fe319b246c..3f966ce54bc54f87e3af8e1290d7cc5510f927fa 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,438 @@ ---- -title: Starvector 1b Im2svg -emoji: 😻 -colorFrom: pink -colorTo: indigo -sdk: docker -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+<div align="center">
+
+# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
+
+[arXiv](https://arxiv.org/abs/2312.11556) | [Website](https://starvector.github.io/) | [HF Models: StarVector-1B](https://huggingface.co/starvector/starvector-1b-im2svg) | [HF Models: StarVector-8B](https://huggingface.co/starvector/starvector-8b-im2svg) | [HF Dataset: SVG-Stack](https://huggingface.co/datasets/starvector/svg-stack) | [HF Dataset: SVG-Bench](https://huggingface.co/datasets/starvector/svg-bench)
+
+Juan A. Rodriguez, Abhay Puri, Shubham Agarwal, Issam H. Laradji, Pau Rodriguez, David Vazquez, Chris Pal, Marco Pedersoli
+
+</div>
+
+## 🔥 News
+- March 2025: **StarVector Accepted at CVPR 2025**
+  - StarVector has been accepted at CVPR 2025! [[Link](https://arxiv.org/abs/2312.11556)]
+  - Check out our website for more information [[Link](https://starvector.github.io/)]
+  - StarVector models are now available on the Hugging Face Model Hub! [[Link](https://huggingface.co/starvector/starvector-1b-im2svg)] [[Link](https://huggingface.co/starvector/starvector-8b-im2svg)]
+  - SVG-Bench and SVG-Stack datasets are now available on the Hugging Face Datasets Hub! [[Link](https://huggingface.co/datasets/starvector/svg-bench)] [[Link](https://huggingface.co/datasets/starvector/svg-stack)]
+
+## 🚀 Introduction
+StarVector is a multimodal vision-language model for Scalable Vector Graphics (SVG) generation. It can be used to perform Image2SVG and Text2SVG generation. We pose SVG generation as a code generation task, using the power of multimodal VLMs.
+
+<!-- StarVector teaser figure -->
+
+> **Abstract**: Scalable Vector Graphics (SVGs) are vital for modern image rendering due to their scalability and versatility. Previous SVG generation methods have focused on curve-based vectorization, lacking semantic understanding, often producing artifacts, and struggling with SVG primitives beyond *path* curves. To address these issues, we introduce StarVector, a multimodal large language model for SVG generation. It performs image vectorization by understanding image semantics and using SVG primitives for compact, precise outputs. Unlike traditional methods, StarVector works directly in the SVG code space, leveraging visual understanding to apply accurate SVG primitives. To train StarVector, we create SVG-Stack, a diverse dataset of 2M samples that enables generalization across vectorization tasks and precise use of primitives like ellipses, polygons, and text. We address challenges in SVG evaluation, showing that pixel-based metrics like MSE fail to capture the unique qualities of vector graphics. We introduce SVG-Bench, a benchmark across 10 datasets and 3 tasks: Image-to-SVG, Text-to-SVG generation, and diagram generation. Using this setup, StarVector achieves state-of-the-art performance, producing more compact and semantically rich SVGs.
+
+### Multimodal Architecture
+
+StarVector uses a multimodal architecture to process images and text. When performing Image-to-SVG (image vectorization), the image is projected into visual tokens, and SVG code is generated from them. When performing Text-to-SVG, the model only receives the text instruction (no image is provided), and a novel SVG is created. The LLM is based on StarCoder, which we leverage to transfer coding skills to SVG generation.
+
+<!-- StarVector architecture figure -->
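+
+Below is a minimal, self-contained sketch of this data flow, included for illustration only. It is **not** the StarVector implementation: module names, dimensions, and the adapter design are assumptions (see `starvector/model` in this repository for the real code). An image is encoded into visual tokens, an adapter projects them into the code LLM's embedding space, and the LLM then predicts SVG tokens autoregressively.
+
+```Python
+# Schematic sketch of the Image-to-SVG flow described above (not the real model).
+import torch
+import torch.nn as nn
+
+class ToyImageToSVG(nn.Module):
+    def __init__(self, vision_dim=256, llm_dim=512, vocab_size=1000):
+        super().__init__()
+        # 1) Image encoder: patchify the image into a sequence of visual features
+        self.patchify = nn.Conv2d(3, vision_dim, kernel_size=16, stride=16)
+        # 2) Adapter: project visual features into the code LLM's embedding space
+        self.adapter = nn.Sequential(nn.Linear(vision_dim, llm_dim), nn.GELU(), nn.LayerNorm(llm_dim))
+        # 3) Code LLM (StarCoder-style): consumes [visual tokens; SVG tokens] and predicts SVG tokens
+        self.embed = nn.Embedding(vocab_size, llm_dim)
+        layer = nn.TransformerEncoderLayer(d_model=llm_dim, nhead=8, batch_first=True)
+        self.llm = nn.TransformerEncoder(layer, num_layers=2)  # stand-in only; no causal mask here
+        self.lm_head = nn.Linear(llm_dim, vocab_size)
+
+    def forward(self, image, svg_token_ids):
+        feats = self.patchify(image).flatten(2).transpose(1, 2)  # (B, n_patches, vision_dim)
+        visual_tokens = self.adapter(feats)                      # (B, n_patches, llm_dim)
+        svg_embeds = self.embed(svg_token_ids)                   # (B, n_svg, llm_dim)
+        hidden = self.llm(torch.cat([visual_tokens, svg_embeds], dim=1))
+        return self.lm_head(hidden[:, visual_tokens.shape[1]:])  # next-token logits for the SVG part
+
+model = ToyImageToSVG()
+logits = model(torch.randn(1, 3, 224, 224), torch.randint(0, 1000, (1, 16)))
+print(logits.shape)  # torch.Size([1, 16, 1000])
+```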
+
+## 📖 Table of Contents
+- [💿 Installation](#installation)
+- [🏎️ Quick Start - Image2SVG Generation](#quick-start---image2svg-generation)
+- [🎨 Models](#models)
+- [📊 Datasets](#datasets---svg-bench)
+- [🏋️‍♂️ Training](#training)
+- [🏆 Evaluation on SVG-Bench](#validation-on-svg-benchmarks-svg-bench)
+- [🧩 Demo](#starvector-demo)
+- [📚 Citation](#citation)
+- [📝 License](#license)
+
+
+## Installation
+
+1. Clone this repository and navigate to the star-vector folder:
+```bash
+git clone https://github.com/joanrod/star-vector.git
+cd star-vector
+```
+
+2. Install the package:
+```Shell
+conda create -n starvector python=3.11.3 -y
+conda activate starvector
+pip install --upgrade pip  # enable PEP 660 support
+pip install -e .
+```
+
+3. Install additional packages for training:
+```
+pip install -e ".[train]"
+```
+
+### Upgrade to the latest code base
+
+```Shell
+git pull
+pip install -e .
+```
+
+## Quick Start - Image2SVG Generation
+
+```Python
+from PIL import Image
+from starvector.model.starvector_arch import StarVectorForCausalLM
+from starvector.data.util import process_and_rasterize_svg
+
+model_name = "starvector/starvector-8b-im2svg"
+
+starvector = StarVectorForCausalLM.from_pretrained(model_name)
+
+starvector.cuda()
+starvector.eval()
+
+image_pil = Image.open('assets/examples/sample-0.png')
+image = starvector.process_images([image_pil])[0].cuda()
+batch = {"image": image}
+
+raw_svg = starvector.generate_im2svg(batch, max_length=1000)[0]
+svg, raster_image = process_and_rasterize_svg(raw_svg)
+```
+
+### Use it from HuggingFace AutoModel
+
+```Python
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
+from starvector.data.util import process_and_rasterize_svg
+import torch
+
+model_name = "starvector/starvector-8b-im2svg"
+
+starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True)
+processor = starvector.model.processor
+tokenizer = starvector.model.svg_transformer.tokenizer
+
+starvector.cuda()
+starvector.eval()
+
+image_pil = Image.open('assets/examples/sample-18.png')
+
+image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
+if image.shape[0] != 1:
+    image = image.squeeze(0)
+batch = {"image": image}
+
+raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
+svg, raster_image = process_and_rasterize_svg(raw_svg)
+```
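+
+Both snippets leave you with `svg` (the generated SVG code as a string) and `raster_image` (its rasterization, assumed below to be a PIL image; adjust if `process_and_rasterize_svg` returns something else). A small follow-up to persist the results:
+
+```Python
+# Continues from either snippet above: save the generated SVG code and a raster preview.
+with open("output.svg", "w", encoding="utf-8") as f:
+    f.write(svg)
+
+raster_image.save("output.png")  # compare side by side with the input image_pil
+print(f"Saved output.svg ({len(svg)} characters) and output.png")
+```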
+
+## Models
+
+We provide [Hugging Face 🤗 model checkpoints](https://huggingface.co/collections/starvector/starvector-models-6783b22c7bd4b43d13cb5289) for image2SVG vectorization, for 💫 StarVector-8B and 💫 StarVector-1B. These are the results on SVG-Bench, using the DinoScore metric.
+
+| Method | SVG-Stack | SVG-Fonts | SVG-Icons | SVG-Emoji | SVG-Diagrams |
+|---------------|-----------|-----------|-----------|-----------|--------------|
+| AutoTrace | 0.942 | 0.954 | 0.946 | 0.975 | 0.874 |
+| Potrace | 0.898 | 0.967 | 0.972 | 0.882 | 0.875 |
+| VTracer | 0.954 | 0.964 | 0.940 | 0.981 | 0.882 |
+| Im2Vec | 0.692 | 0.733 | 0.754 | 0.732 | - |
+| LIVE | 0.934 | 0.956 | 0.959 | 0.969 | 0.870 |
+| DiffVG | 0.810 | 0.821 | 0.952 | 0.814 | 0.822 |
+| GPT-4-V | 0.852 | 0.842 | 0.848 | 0.850 | - |
+| 💫 StarVector-1B (🤗 [Link](https://huggingface.co/starvector/starvector-1b-im2svg)) | 0.926 | 0.978 | 0.975 | 0.929 | 0.943 |
+| 💫 StarVector-8B (🤗 [Link](https://huggingface.co/starvector/starvector-8b-im2svg)) | **0.966** | **0.982** | **0.984** | **0.981** | **0.959** |
+
+*Note*: StarVector models will not work for natural images or illustrations, as they have not been trained on those images. They excel in vectorizing icons, logotypes, technical diagrams, graphs, and charts.
+
+## Datasets - SVG-Bench
+SVG-Bench is a benchmark for evaluating SVG generation models. It contains 10 datasets and 3 tasks: Image-to-SVG, Text-to-SVG, and Diagram-to-SVG.
+
+See our [Hugging Face 🤗 Dataset Collection](https://huggingface.co/collections/starvector/starvector-svg-datasets-67811204a76475be4dd66d09).
+
+| Dataset | Train | Val | Test | Token Length | SVG Primitives | Annotation |
+|-----------------|--------|-------|------|------------------|----------------|----------------|
+| SVG-Stack (🤗 [Link](https://huggingface.co/datasets/starvector/svg-stack)) | 2.1M | 108k | 5.7k | 1,822 ± 1,808 | All | [Captions](https://huggingface.co/datasets/starvector/text2svg-stack) |
+| SVG-Stack_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-stack-simple)) | 601k | 30.1k | 1.5k | 2k ± 918 | Vector path | - |
+| SVG-Diagrams (🤗 [Link](https://huggingface.co/datasets/starvector/svg-diagrams)) | - | - | 472 | 3,486 ± 1,918 | All | - |
+| SVG-Fonts (🤗 [Link](https://huggingface.co/datasets/starvector/svg-fonts)) | 1.8M | 91.5k | 4.8k | 2,121 ± 1,868 | Vector path | Font letter |
+| SVG-Fonts_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-fonts-simple)) | 1.4M | 71.7k | 3.7k | 1,722 ± 723 | Vector path | Font letter |
+| SVG-Emoji (🤗 [Link](https://huggingface.co/datasets/starvector/svg-emoji)) | 8.7k | 667 | 668 | 2,551 ± 1,805 | All | - |
+| SVG-Emoji_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-emoji-simple)) | 580 | 57 | 96 | 2,448 ± 1,026 | Vector path | - |
+| SVG-Icons (🤗 [Link](https://huggingface.co/datasets/starvector/svg-icons)) | 80.4k | 6.2k | 2.4k | 2,449 ± 1,543 | Vector path | - |
+| SVG-Icons_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-icons-simple)) | 80,435 | 2,836 | 1,277 | 2,005 ± 824 | Vector path | - |
+| SVG-FIGR (🤗 [Link](https://huggingface.co/datasets/starvector/FIGR-SVG)) | 270k | 27k | 3k | 5,342 ± 2,345 | Vector path | Class, Caption |
+
+> We offer a summary of statistics about the datasets used in our training and evaluation experiments. These datasets are included in SVG-Bench. The subscript _sim_ stands for the simplified version of the dataset, as required by some baselines.
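+
+All of these datasets can be loaded directly from the Hub with the 🤗 `datasets` library. A minimal sketch (split names follow the table above; the column names are not documented here, so print one sample to inspect them):
+
+```Python
+from datasets import load_dataset
+
+# Pull the SVG-Stack test split (5.7k samples, per the table above) from the Hub.
+svg_stack = load_dataset("starvector/svg-stack", split="test")
+print(svg_stack)              # number of rows and column names
+print(next(iter(svg_stack)))  # inspect one raw sample
+```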
+
+## Training
+
+### Confirm dependencies are installed
+
+```bash
+pip install -e ".[train]"
+```
+
+### Set environment variables
+We recommend setting the following environment variables:
+
+```bash
+export HF_HOME=
+export HF_TOKEN=
+export WANDB_API_KEY=
+export OUTPUT_DIR=
+```
+
+Then navigate to the root of the repository:
+
+```Shell
+cd star-vector
+```
+
+### Image2SVG Pretraining (Stage 1)
+
+We use different training approaches for StarVector-1B and StarVector-8B: StarVector-1B can be trained with DeepSpeed, while StarVector-8B requires FSDP.
+
+#### StarVector-1B Training
+
+You can use the following command to train StarVector-1B on SVG-Stack for the Image2SVG vectorization task, using DeepSpeed and Accelerate:
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/im2svg-stack.yaml
+```
+
+#### StarVector-8B Training
+
+You can use the following command to train StarVector-8B on SVG-Stack for the Image2SVG vectorization task, using FSDP and Accelerate. We provide the torchrun command to support multi-node and multi-GPU training.
+
+```bash
+# StarVector-8B
+torchrun \
+  --nproc-per-node=8 \
+  --nnodes=1 \
+  starvector/train/train.py \
+  config=configs/models/starvector-8b/im2svg-stack.yaml
+```
+
+### Finetuning StarVector (Stage 2)
+
+After pretraining StarVector on image vectorization, we finetune it on additional SVG tasks such as Text2SVG, and on the SVG-Bench datasets.
+
+#### Text2SVG Finetuning
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/text2svg-stack.yaml
+
+# StarVector-8B
+torchrun \
+  --nproc-per-node=8 \
+  --nnodes=1 \
+  starvector/train/train.py \
+  config=configs/models/starvector-8b/text2svg-stack.yaml
+```
+
+#### SVG-Bench Finetuning
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/im2svg-{fonts,icons,emoji}.yaml
+
+# StarVector-8B
+torchrun \
+  --nproc-per-node=8 \
+  --nnodes=1 \
+  starvector/train/train.py \
+  config=configs/models/starvector-8b/im2svg-{fonts,icons,emoji}.yaml
+```
+
+We also provide shell scripts in `scripts/train/*`.
+
+## Validation on SVG Benchmarks (⭐ SVG-Bench)
+
+We validate StarVector on the ⭐ SVG-Bench benchmark. We provide the SVGValidator class, which lets you run StarVector using **1) the HuggingFace generation backend** or **2) the vLLM backend**. The latter is substantially faster thanks to its use of PagedAttention.
+
+### HuggingFace Generation Backend
+Let's start with the evaluation of StarVector-1B and StarVector-8B on SVG-Stack, using the HuggingFace generation backend (StarVectorHFAPIValidator). To override the input arguments, you can add CLI args following the YAML file structure.
+
+```bash
+# StarVector-1B on SVG-Stack, using the HuggingFace backend
+python starvector/validation/validate.py \
+config=configs/generation/hf/starvector-1b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+
+# StarVector-8B on SVG-Stack, using the vanilla HuggingFace generation API
+python starvector/validation/validate.py \
+config=configs/generation/hf/starvector-8b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+```
+
+### vLLM Backend
+
+To use the vLLM backend (StarVectorVLLMAPIValidator), first install our StarVector fork of vLLM, available [here](https://github.com/starvector/vllm):
+
+```bash
+git clone https://github.com/starvector/vllm.git
+cd vllm
+pip install -e .
+```
+
+Then, launch the validation using the vLLM config file (it uses StarVectorVLLMValidator):
+
+```bash
+# StarVector-1B
+python starvector/validation/validate.py \
+config=configs/generation/vllm/starvector-1b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+
+# StarVector-8B
+python starvector/validation/validate.py \
+config=configs/generation/vllm/starvector-8b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+```
+
+#### Generate using Temperature Sweep
+Temperature sweep is an evaluation technique where we:
+1. Generate multiple SVG candidates using different temperature values (controlling randomness in generation)
+2. Evaluate each candidate using the DinoScore metric
+3. Select the best-performing SVG as the final output
+
+This approach improves result quality by exploring multiple generation possibilities, though it requires more computation time. A schematic sketch of the selection loop is shown after the launch commands below.
+
+```bash
+# StarVector-1B (vLLM)
+python starvector/validation/run_validator.py \
+config=configs/generation/vllm/starvector-1b/im2svg.yaml \
+dataset.name=svg-stack \
+generation_params.generation_sweep=True \
+generation_params.num_generations_different_temp=5 \
+generation_params.min_temperature=0.0 \
+generation_params.max_temperature=0.5
+
+# StarVector-8B (vLLM)
+python starvector/validation/run_validator.py \
+config=configs/generation/vllm/starvector-8b/im2svg.yaml \
+dataset.name=svg-stack \
+generation_params.generation_sweep=True \
+generation_params.num_generations_different_temp=10 \
+generation_params.min_temperature=0.0 \
+generation_params.max_temperature=0.5
+```
+
+We provide evaluation scripts in `scripts/eval/*`.
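+
+For reference, the sweep-and-select logic behind `generation_sweep` can be sketched as follows. This is an illustration only, not the validator implementation: `generate_svg` and `dino_score` are placeholder callables standing in for the project's generation and metric code.
+
+```Python
+# Schematic sketch of the temperature sweep used at evaluation time (illustration only).
+import numpy as np
+
+def temperature_sweep(image, generate_svg, dino_score, min_t=0.0, max_t=0.5, n=5):
+    """Generate one SVG per temperature and keep the candidate with the best DinoScore.
+
+    `generate_svg(image, temperature)` and `dino_score(image, svg)` are placeholders
+    for the project's generation and metric code.
+    """
+    best_svg, best_score = None, -float("inf")
+    for temperature in np.linspace(min_t, max_t, n):
+        candidate = generate_svg(image, temperature=temperature)  # 1) sample a candidate
+        score = dino_score(image, candidate)                      # 2) score it against the input image
+        if score > best_score:                                    # 3) keep the best one
+            best_svg, best_score = candidate, score
+    return best_svg, best_score
+```
+
+In the actual validator, the temperature range and the number of candidates are controlled by `generation_params.min_temperature`, `generation_params.max_temperature`, and `generation_params.num_generations_different_temp`, as in the commands above.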
+
+## StarVector Demo
+
+The demo provides two options for converting images to SVG code:
+1. HuggingFace generation functionality
+2. vLLM (recommended), which offers faster generation speed
+
+### Option 1: HuggingFace Generation with Gradio Web UI
+
+We provide a Gradio web UI for you to play with our model.
+
+#### Launch a controller
+```Shell
+python -m starvector.serve.controller --host 0.0.0.0 --port 10000
+```
+
+#### Launch a Gradio web server
+```Shell
+python -m starvector.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7000
+```
+You have just launched the Gradio web interface; open it with the URL printed on the screen. You may notice that there is no model in the model list yet. Do not worry, as we have not launched any model worker; the list updates automatically once you do.
+
+#### Launch a model worker
+
+This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model, specified in `--model-path`.
+
+```Shell
+python -m starvector.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path joanrodai/starvector-1.4b
+```
+Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now refresh your Gradio web UI, and you will see the model you just launched in the model list.
+
+You can launch as many workers as you want and compare different model checkpoints in the same Gradio interface. Keep `--controller` the same, and use a different `--port` and `--worker` for each additional worker, for example:
+
+```Shell
+python -m starvector.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <port> --worker http://localhost:<port> --model-path <model-path>
+```
+
+### Option 2: Launch VLLM
+
+0. Remember to clone the starvector/vllm fork (it has modifications for starvector):
+
+```Shell
+git clone https://github.com/starvector/vllm.git
+cd vllm
+pip install -e .
+```
+
+1. Call this to launch the vLLM endpoint:
+
+```Shell
+vllm serve starvector/starvector-1b-im2svg --chat-template configs/chat-template.jinja --trust-remote-code --port 8000 --max-model-len 8192
+```
+
+2. Create the demo for vLLM:
+
+```Shell
+python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000
+python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7000
+python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name starvector/starvector-1b-im2svg --vllm-base-url http://localhost:8000
+```
+
+3. Add more models by serving them with vLLM and launching a new model worker:
+
+```Shell
+vllm serve starvector/starvector-8b-im2svg --chat-template configs/chat-template.jinja --trust-remote-code --port 8001 --max-model-len 16384
+
+python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-name starvector/starvector-8b-im2svg --vllm-base-url http://localhost:8001
+```
+
+## Citation
+```
+@misc{rodriguez2024starvector,
+      title={StarVector: Generating Scalable Vector Graphics Code from Images and Text},
+      author={Juan A. Rodriguez and Abhay Puri and Shubham Agarwal and Issam H. Laradji and Pau Rodriguez and Sai Rajeswar and David Vazquez and Christopher Pal and Marco Pedersoli},
+      year={2024},
+      eprint={2312.11556},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2312.11556},
+}
+```
+
+## License
+This project is licensed under the Apache License, Version 2.0 - see the [LICENSE](LICENSE) file for details.
diff --git a/assets/examples/sample-0.png b/assets/examples/sample-0.png new file mode 100644 index 0000000000000000000000000000000000000000..b4b8b2116ca06f443dbaa1447df2e065a00ee8d8 Binary files /dev/null and b/assets/examples/sample-0.png differ diff --git a/assets/examples/sample-1.png b/assets/examples/sample-1.png new file mode 100644 index 0000000000000000000000000000000000000000..97a1895ed9729527626f9d57e99c33fb1d547ba2 Binary files /dev/null and b/assets/examples/sample-1.png differ diff --git a/assets/examples/sample-15.png b/assets/examples/sample-15.png new file mode 100644 index 0000000000000000000000000000000000000000..7f7ae2ba7b6032dd244b8c0f978bd9a8e051c0f6 Binary files /dev/null and b/assets/examples/sample-15.png differ diff --git a/assets/examples/sample-16.png b/assets/examples/sample-16.png new file mode 100644 index 0000000000000000000000000000000000000000..1b4592f83c5d0c419710d293c7500a4b49acde51 Binary files /dev/null and b/assets/examples/sample-16.png differ diff --git a/assets/examples/sample-17.png b/assets/examples/sample-17.png new file mode 100644 index 0000000000000000000000000000000000000000..3e5d215c67f4d8abccb14df5ab4d520f4f0d5410 Binary files /dev/null and b/assets/examples/sample-17.png differ diff --git a/assets/examples/sample-18.png b/assets/examples/sample-18.png new file mode 100644 index 0000000000000000000000000000000000000000..5ee9f0f6548b19598bcd9de563c800e1d419a4c2 Binary files /dev/null and b/assets/examples/sample-18.png differ diff --git a/assets/examples/sample-4.png b/assets/examples/sample-4.png new file mode 100644 index 0000000000000000000000000000000000000000..fd36563d9ffbbd724e2a26d5c548bde7b8281007 Binary files /dev/null and b/assets/examples/sample-4.png differ diff --git a/assets/examples/sample-6.png b/assets/examples/sample-6.png new file mode 100644 index 0000000000000000000000000000000000000000..d142d386d86a6435582dcfcd63df04547bfa248f Binary files /dev/null and b/assets/examples/sample-6.png differ diff --git a/assets/examples/sample-7.png b/assets/examples/sample-7.png new file mode 100644 index 0000000000000000000000000000000000000000..33e833ad6bf946a1db2ab8f63af7ad2d54658ee6 Binary files /dev/null and b/assets/examples/sample-7.png differ diff --git a/configs/accelerate/1-gpu.yaml b/configs/accelerate/1-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a05df5c14af1b22beb56e485c490dc3c2682a4 --- /dev/null +++ b/configs/accelerate/1-gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/configs/accelerate/2-gpu.yaml b/configs/accelerate/2-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..691fb47960230a006b2fc5db487c08cd62bd7f45 --- /dev/null +++ b/configs/accelerate/2-gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/configs/accelerate/4-gpu.yaml 
b/configs/accelerate/4-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77b8a73faa6f864b4e39aa58407d790651b2758f --- /dev/null +++ b/configs/accelerate/4-gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/configs/accelerate/8-gpu.yaml b/configs/accelerate/8-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ad92a2259c59c0ab263ac1ebdfc4ca41ae6b831 --- /dev/null +++ b/configs/accelerate/8-gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'no' +dynamo_backend: 'NO' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/configs/accelerate/deepspeed-1-gpu.yaml b/configs/accelerate/deepspeed-1-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1ece4a715f4dda7564e71ae3ac3c225396c49e8 --- /dev/null +++ b/configs/accelerate/deepspeed-1-gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +fsdp_config: {} +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/accelerate/deepspeed-2-gpu.yaml b/configs/accelerate/deepspeed-2-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a18a003c959f3edabf3180c51fab6ac23945a095 --- /dev/null +++ b/configs/accelerate/deepspeed-2-gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +fsdp_config: {} +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/accelerate/deepspeed-4-gpu.yaml b/configs/accelerate/deepspeed-4-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d4b37cc73e6d11f82ff0759b431b45e38883e97 --- /dev/null +++ b/configs/accelerate/deepspeed-4-gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +fsdp_config: {} +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git 
a/configs/accelerate/deepspeed-8-gpu.yaml b/configs/accelerate/deepspeed-8-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..359ffbb53f673488f954bb38b364953e9d4beb93 --- /dev/null +++ b/configs/accelerate/deepspeed-8-gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +fsdp_config: {} +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/accelerate/deespeed.json b/configs/accelerate/deespeed.json new file mode 100644 index 0000000000000000000000000000000000000000..451fa0c0283430ff8ef14624540616aa1fde84f1 --- /dev/null +++ b/configs/accelerate/deespeed.json @@ -0,0 +1,29 @@ +{ + "bf16": { + "enabled": false + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu" + }, + "offload_param": { + "device": "cpu" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "sub_group_size": 1e9, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": 4, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/configs/accelerate/val-deepspeed-1-gpu.yaml b/configs/accelerate/val-deepspeed-1-gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2f41f87c8c0219c1f2fd478672c9b87e0b86099 --- /dev/null +++ b/configs/accelerate/val-deepspeed-1-gpu.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +fsdp_config: {} +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/chat-template.jinja b/configs/chat-template.jinja new file mode 100644 index 0000000000000000000000000000000000000000..53512688744a9e895b6a9d994f0778d95c205eca --- /dev/null +++ b/configs/chat-template.jinja @@ -0,0 +1 @@ +{% for message in messages %}{{ message.content }}{% endfor %} \ No newline at end of file diff --git a/configs/generation/hf/starvector-1b/im2svg.yaml b/configs/generation/hf/starvector-1b/im2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50b3d69fd258a0b9355f1b34d0b2fa01644b71f6 --- /dev/null +++ b/configs/generation/hf/starvector-1b/im2svg.yaml @@ -0,0 +1,52 @@ +# General configuration +run: + project_name: "starvector-RL-eval" + out_dir: "/mnt/starvector/RL/logs/eval" + device: cuda + report_to: wandb + run_id: test-run + log_images: false + +# Model configuration +model: + name: "starvector/starvector-1b-im2svg" # Required: Model name for HF-based model + from_checkpoint: false + generation_engine: "hf" + task: im2svg + torch_dtype: float16 + # image_processor: clip # is this needed? 
+ +# Dataset configuration +dataset: + dataset_name: starvector/svg-stack-RL # Required: Name of the dataset to evaluate on + config_name: null # in bigodcs set Image2SVG + split: test + batch_size: 8 + num_workers: 4 + im_size: 224 + num_samples: -1 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + # Text generation parameters + max_length: 7800 + min_length: 10 + num_beams: 1 + temperature: 0.2 + generation_sweep: false # Controls multi-temperature sampling, rank based sampling + # num_generations_different_temp: 1 + # min_temperature: 0.0 + # max_temperature: 0.5 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 1.0 + presence_penalty: 0.0 # only used in vllm + frequency_penalty: 0.0 + top_p: 0.95 + do_sample: true # turn this off for greedy decoding + use_nucleus_sampling: true + logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id + stream: false + + diff --git a/configs/generation/hf/starvector-1b/text2svg.yaml b/configs/generation/hf/starvector-1b/text2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68c6f5af19327c02f9f776ae66715531af83b63e --- /dev/null +++ b/configs/generation/hf/starvector-1b/text2svg.yaml @@ -0,0 +1,44 @@ +# General configuration +run: + project_name: "starvector-RL-eval" + out_dir: "eval_results" + device: cuda + report_to: wandb + run_id: test-run + log_images: false + +# Model configuration +model: + name: "starvector/starvector-1b-text2svg" + generation_engine: "hf" + task: text2svg + torch_dtype: bfloat16 + image_processor: null + +# Dataset configuration +dataset: + name: svg-stack + batch_size: 2 + num_workers: 4 + num_samples: -1 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + presence_penalty: 0.0 # only used in vllm + frequency_penalty: 0.0 + top_p: 0.95 + do_sample: true # turn this off for greedy decoding + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 + logit_bias: 10 # if this is not false, the model will be biased to the svg_end_token_id + stream: false \ No newline at end of file diff --git a/configs/generation/hf/starvector-8b/im2svg.yaml b/configs/generation/hf/starvector-8b/im2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c94a8fb607a2b1acbcb46c605b50d7e1c7142a4 --- /dev/null +++ b/configs/generation/hf/starvector-8b/im2svg.yaml @@ -0,0 +1,50 @@ +# General configuration +run: + project_name: "starvector-RL-eval" + out_dir: "eval_results" + device: cuda + report_to: wandb + run_id: test-run + log_images: false + +# Model configuration +model: + name: "starvector/starvector-8b-im2svg" + from_checkpoint: false + generation_engine: "hf" + task: im2svg + torch_dtype: bfloat16 + # image_processor: siglip_384 + +# Dataset configuration +dataset: + dataset_name: starvector/svg-stack-RL + config_name: null + split: test + batch_size: 2 + num_workers: 4 + im_size: 384 + num_samples: -1 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + max_length: 16000 + min_length: 10 + num_beams: 1 + temperature: 0.7 + generation_sweep: false # Controls multi-temperature 
sampling, rank based sampling + # num_generations_different_temp: 1 + # min_temperature: 0.0 + # max_temperature: 0.5 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + presence_penalty: 0.0 # only used in vllm + frequency_penalty: 0.0 + top_p: 0.95 + do_sample: true # turn this off for greedy decoding + use_nucleus_sampling: true + logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id + stream: false + diff --git a/configs/generation/hf/starvector-8b/text2svg.yaml b/configs/generation/hf/starvector-8b/text2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..083b2e454e818ef163cbdeb2b46c4784371ed9f2 --- /dev/null +++ b/configs/generation/hf/starvector-8b/text2svg.yaml @@ -0,0 +1,40 @@ +# General configuration +run: + out_dir: "eval_results" + device: cuda + +# Model configuration +model: + name: "starvector/starvector-8b-text2svg" + generation_engine: "hf" + task: text2svg + torch_dtype: bfloat16 + image_processor: clip + +# Dataset configuration +dataset: + name: svg-stack + batch_size: 2 + num_workers: 4 + num_samples: -1 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + presence_penalty: 0.0 # only used in vllm + frequency_penalty: 0.0 + top_p: 0.95 + do_sample: true # turn this off for greedy decoding + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 + logit_bias: 10 # if this is not false, the model will be biased to the svg_end_token_id + stream: false \ No newline at end of file diff --git a/configs/generation/vllm/starvector-1b/im2svg.yaml b/configs/generation/vllm/starvector-1b/im2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a70c388860a30b32a9eaa8fd739f470f10b35c3 --- /dev/null +++ b/configs/generation/vllm/starvector-1b/im2svg.yaml @@ -0,0 +1,51 @@ +# General configuration +run: + project_name: "starvector-RL-eval" + out_dir: "/mnt/starvector/RL/logs/eval" + report_to: wandb + run_id: test-eval3 + log_images: false + +# Model configuration +model: + name: "starvector/starvector-1b-im2svg" # Required: Model name for HF-based model + from_checkpoint: false + generation_engine: "vllm" + task: im2svg + torch_dtype: float16 + # image_processor: clip # is this needed? 
+ +# Dataset configuration +dataset: + dataset_name: starvector/svg-stack-RL # Required: Name of the dataset to evaluate on + config_name: null # in bigodcs set Image2SVG + split: test + batch_size: 8 + num_workers: 8 + im_size: 224 + num_samples: 500 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + # Text generation parameters + max_length: 7933 # 8192 - (visual tokens) + num_generations: 1 + min_length: 10 + num_beams: 1 + temperature: 0.5 + generation_sweep: false # Controls multi-temperature sampling, rank based sampling + # num_generations_different_temp: 5 + # min_temperature: 0.0 + # max_temperature: 0.5 + num_captions: 1 + frequency_penalty: 0.0 + presence_penalty: 0.0 + repetition_penalty: 1.0 + top_p: 0.9 + logit_bias: 5 # if this is not false, the model will be biased to the svg_end_token_id + min_p: 0.0 + top_k: -1 + stream: false + + diff --git a/configs/generation/vllm/starvector-1b/text2svg.yaml b/configs/generation/vllm/starvector-1b/text2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..585ee21a1bf540a8a9ef41f2a6e802689ffefea8 --- /dev/null +++ b/configs/generation/vllm/starvector-1b/text2svg.yaml @@ -0,0 +1,40 @@ +# General configuration +run: + out_dir: "eval_results" + api: + key: "EMPTY" + base_url: "http://0.0.0.0:40000/v1" + +# Model configuration +model: + name: "starvector/starvector-1b-text2svg" + generation_engine: "vllm" + task: text2svg + torch_dtype: bfloat16 + image_processor: null + +# Dataset configuration +dataset: + name: svg-stack + batch_size: 2 + num_workers: 4 + num_samples: -1 + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + # Text generation parameters + max_length: 16384 + min_length: 10 + num_beams: 1 + temperature: 0.6 + generation_sweep: false # Controls multi-temperature sampling, rank based sampling + # num_generations_different_temp: 1 + # min_temperature: 0.0 + # max_temperature: 0.5 + num_captions: 1 + frequency_penalty: 0.0 + presence_penalty: 0.0 + top_p: 0.9 + logit_bias: False # if this is not false, the model will be biased to the svg_end_token_id + stream: false \ No newline at end of file diff --git a/configs/generation/vllm/starvector-8b/im2svg.yaml b/configs/generation/vllm/starvector-8b/im2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..727000f3ecffb9a92ca0ec4e3d32eec6e5907b47 --- /dev/null +++ b/configs/generation/vllm/starvector-8b/im2svg.yaml @@ -0,0 +1,45 @@ +# General configuration +run: + out_dir: "eval_results" + api: + key: "EMPTY" + base_url: "http://0.0.0.0:40000/v1" + +# Model configuration +model: + name: "starvector/starvector-8b-im2svg" + generation_engine: "vllm" + task: im2svg + torch_dtype: bfloat16 + image_processor: siglip_384 + +# Dataset configuration +dataset: + name: svg-stack + batch_size: 2 + num_workers: 4 + im_size: 384 + num_samples: -1 + dpi: 2 + scale: 300 + + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + # Text generation parameters + max_length: 15806 # 16384 - (visual tokens) + min_length: 10 + num_beams: 1 + temperature: 0.6 + generation_sweep: false # Controls multi-temperature sampling, rank based sampling + # num_generations_different_temp: 1 + # min_temperature: 0.0 + # max_temperature: 0.5 + 
num_captions: 1 + frequency_penalty: 0.0 + presence_penalty: 0.0 + top_p: 0.9 + logit_bias: False # if this is not false, the model will be biased to the svg_end_token_id + stream: false + diff --git a/configs/generation/vllm/starvector-8b/text2svg.yaml b/configs/generation/vllm/starvector-8b/text2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43c6cd5c441e62523ca4c108e6c38538b1d82ac2 --- /dev/null +++ b/configs/generation/vllm/starvector-8b/text2svg.yaml @@ -0,0 +1,41 @@ +# General configuration +run: + out_dir: "eval_results" + api: + key: "EMPTY" + base_url: "http://0.0.0.0:40000/v1" + +# Model configuration +model: + name: "starvector/starvector-8b-text2svg" + generation_engine: "vllm" + task: text2svg + torch_dtype: bfloat16 + image_processor: clip + +# Dataset configuration +dataset: + name: svg-stack + batch_size: 2 + num_workers: 4 + num_samples: -1 + + +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +generation_params: + # Text generation parameters + max_length: 16384 + min_length: 10 + num_beams: 1 + temperature: 0.6 + generation_sweep: false # Controls multi-temperature sampling, rank based sampling + # num_generations_different_temp: 1 + # min_temperature: 0.0 + # max_temperature: 0.5 + num_captions: 1 + frequency_penalty: 0.0 + presence_penalty: 0.0 + top_p: 0.9 + logit_bias: False # if this is not false, the model will be biased to the svg_end_token_id + stream: false diff --git a/configs/metrics/im2svg.yaml b/configs/metrics/im2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bd530b33949bd0534ea646c03a3213eeef65112 --- /dev/null +++ b/configs/metrics/im2svg.yaml @@ -0,0 +1,12 @@ +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/metrics/text2svg.yaml b/configs/metrics/text2svg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f8fb670bdcc3b0be824015da270ebcf66b79cfd --- /dev/null +++ b/configs/metrics/text2svg.yaml @@ -0,0 +1,12 @@ +metrics: + L2: false + Masked-L2: false + LPIPS: false + SSIM: false + FID: true + FID_clip: true + CLIPScore: true + CountTokenLength: true + ratio_post_processed: true + ratio_non_compiling: true + DinoScore: false diff --git a/configs/models/default.yaml b/configs/models/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1beef0b1a7aa97d97f68009a0e35c7770f8166a3 --- /dev/null +++ b/configs/models/default.yaml @@ -0,0 +1,93 @@ +project: + project: starvector-im2svg + use_wandb: false + entity: abc + copy_code: false +model: + max_length: 8192 + model_name: null # in case of creating a new model, set this to None (null) + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: layer_norm + init_type: glorot + dropout: 0.1 + task: im2svg + transformer_layer_cls: None # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 3 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + 
start_generation_at_step: 0 + train_image_encoder: true + train_LLM: true + use_gradient_checkpointing: false +fsdp: + enable: false +data: + num_workers: 4 + train: + batch_size: 2 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: ServiceNow/svg-stack + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 2 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: ServiceNow/svg-stack + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false +generation: + max_length: 8192 + min_length: 10 + num_beams: 3 + num_captions: 1 + num_generations_different_temp: 1.5 + start_temperature: 0.5 + repetition_penalty: 1.0 + length_penalty: 1.0 + temperature: 1.0 + top_p: 0.9 + use_nucleus_sampling: true + im_size: 224 + dpi: 2 + scale: 300 + num_samples_to_generate: -1 + log_wandb_images: true + start_generation_at_step: -1 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/models/starvector-1b/im2svg-emoji.yaml b/configs/models/starvector-1b/im2svg-emoji.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d17354dd48b43a6ffa324dfa719797df21ec29f --- /dev/null +++ b/configs/models/starvector-1b/im2svg-emoji.yaml @@ -0,0 +1,87 @@ +project: + project: starvector-1b-im2svg + use_wandb: false + entity: abc + copy_code: false +model: + max_length: 8192 + model_name: null # in case of creating a new model, set this to None (null) + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: batch_norm + init_type: glorot + dropout: 0.1 + task: im2svg + transformer_layer_cls: null # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 3 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 10 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + train_image_encoder: true + train_LLM: true + use_gradient_checkpointing: false +fsdp: + enable: false +data: + num_workers: 16 + train: + batch_size: 2 + target: starvector.data.emojisvg.EmojiSVGDataset + params: + split: train + dataset_name: starvector/svg-emoji + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 8 + target: starvector.data.emojisvg.EmojiSVGDataset + params: + split: test + dataset_name: starvector/svg-emoji + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false +generation: + max_length: 8192 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 224 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/models/starvector-1b/im2svg-fonts.yaml b/configs/models/starvector-1b/im2svg-fonts.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..49ebc43ffff267f290f3c583c3394a20a74e7a69 --- /dev/null +++ b/configs/models/starvector-1b/im2svg-fonts.yaml @@ -0,0 +1,87 @@ +project: + project: starvector-1b-im2svg + use_wandb: false + entity: abc + copy_code: false +model: + max_length: 8192 + model_name: null # in case of creating a new model, set this to None (null) + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: batch_norm + init_type: glorot + dropout: 0.1 + task: im2svg + transformer_layer_cls: null # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 3 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + train_image_encoder: true + train_LLM: true + use_gradient_checkpointing: false +fsdp: + enable: false +data: + num_workers: 16 + train: + batch_size: 4 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: train + dataset_name: starvector/svg-fonts + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 8 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: test + dataset_name: starvector/svg-fonts + im_size: 224 + num_samples: 1000 + transforms: false + select_dataset_name: false +generation: + max_length: 8192 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 224 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/models/starvector-1b/im2svg-icons.yaml b/configs/models/starvector-1b/im2svg-icons.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b46e9eaa29d478958b50d8a7ade440a33949eea --- /dev/null +++ b/configs/models/starvector-1b/im2svg-icons.yaml @@ -0,0 +1,87 @@ +project: + project: starvector-1b-im2svg + use_wandb: false + entity: abc + copy_code: false +model: + max_length: 8192 + model_name: null # in case of creating a new model, set this to None (null) + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: batch_norm + init_type: glorot + dropout: 0.1 + task: im2svg + transformer_layer_cls: null # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 3 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 10 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + train_image_encoder: true + train_LLM: true + use_gradient_checkpointing: false +fsdp: + enable: false +data: + num_workers: 16 + train: + batch_size: 4 + target: starvector.data.iconsvg.SVGIconsDataset + params: + split: train + dataset_name: starvector/svg-icons + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 8 + 
target: starvector.data.iconsvg.SVGIconsDataset + params: + split: test + dataset_name: starvector/svg-icons + im_size: 224 + num_samples: 1000 + transforms: false + select_dataset_name: false +generation: + max_length: 8192 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 224 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/models/starvector-1b/im2svg-stack.yaml b/configs/models/starvector-1b/im2svg-stack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f07d964bdea8776c526008182a79d578cbc3411 --- /dev/null +++ b/configs/models/starvector-1b/im2svg-stack.yaml @@ -0,0 +1,87 @@ +project: + project: starvector-1b-im2svg + use_wandb: false + entity: abc + copy_code: false +model: + max_length: 8192 + model_name: null # in case of creating a new model, set this to None (null) + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: batch_norm + init_type: glorot + dropout: 0.1 + task: im2svg + transformer_layer_cls: null # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 3 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + train_image_encoder: true + train_LLM: true + use_gradient_checkpointing: false +fsdp: + enable: false +data: + num_workers: 4 + train: + batch_size: 2 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: starvector/svg-stack + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 2 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: starvector/svg-stack + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false +generation: + max_length: 8192 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 224 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true diff --git a/configs/models/starvector-1b/text2svg-figr.yaml b/configs/models/starvector-1b/text2svg-figr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3ab8236fc40e6f849f778a196c0f471825b92eb --- /dev/null +++ b/configs/models/starvector-1b/text2svg-figr.yaml @@ -0,0 +1,94 @@ +project: + project: starvector-1b-text2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 8192 + model_name: starvector/starvector-1b-im2svg + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: text2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + 
use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 2 + lr: 2e-5 + gradient_accumulation_steps: 8 + lr_scheduler: cosine + lr_warmup_steps: 100 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: false + train_image_encoder: false + train_LLM: true +fsdp: + enable: false # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 4 + target: starvector.data.figrsvg.FigrSVGDataset + params: + split: train + dataset_name: starvector/FIGR-SVG + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + test: + batch_size: 8 + target: starvector.data.figrsvg.FigrSVGDataset + params: + split: test + dataset_name: starvector/FIGR-SVG + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: false + Masked-L2: false + LPIPS: false + SSIM: false + FID: false + FID_clip: false + CLIPScore: true + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: false \ No newline at end of file diff --git a/configs/models/starvector-1b/text2svg-stack.yaml b/configs/models/starvector-1b/text2svg-stack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7982d15c356e4b220ef3a700165b805073a08b7 --- /dev/null +++ b/configs/models/starvector-1b/text2svg-stack.yaml @@ -0,0 +1,96 @@ +project: + project: starvector-1b-text2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 8192 + model_name: starvector/starvector-1b-im2svg + starcoder_model_name: bigcode/starcoderbase-1b + pretrained: true + image_encoder_type: clip + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: text2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 2 + lr: 2e-5 + gradient_accumulation_steps: 8 + lr_scheduler: cosine + lr_warmup_steps: 100 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: false + train_image_encoder: false + train_LLM: true +fsdp: + enable: false # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: starvector/text2svg-stack + im_size: 224 + num_samples: -1 + transforms: false + select_dataset_name: false + image_processor: siglip_384 + test: 
+ batch_size: 4 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: starvector/text2svg-stack + im_size: 224 + num_samples: 64 + transforms: false + select_dataset_name: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: false + Masked-L2: false + LPIPS: false + SSIM: false + FID: false + FID_clip: false + CLIPScore: true + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: false \ No newline at end of file diff --git a/configs/models/starvector-8b/im2svg-emoji.yaml b/configs/models/starvector-8b/im2svg-emoji.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46495bc0bc368a25cce371bda99f9f39727b1842 --- /dev/null +++ b/configs/models/starvector-8b/im2svg-emoji.yaml @@ -0,0 +1,94 @@ +project: + project: starvector-8b-im2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: im2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.emojisvg.EmojiSVGDataset + params: + split: train + dataset_name: starvector/svg-emoji + im_size: 384 + num_samples: -1 + transforms: false + image_processor: siglip_384 + test: + batch_size: 1 + target: starvector.data.emojisvg.EmojiSVGDataset + params: + split: test + dataset_name: starvector/svg-emoji + im_size: 384 + num_samples: 128 + transforms: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true \ No newline at end of file diff --git a/configs/models/starvector-8b/im2svg-fonts-simple.yaml b/configs/models/starvector-8b/im2svg-fonts-simple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf4726b4679a000661219049ccbb102c7dbaed31 --- /dev/null +++ b/configs/models/starvector-8b/im2svg-fonts-simple.yaml @@ -0,0 +1,94 @@ +project: + project: 
starvector-8b-im2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: im2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: train + dataset_name: starvector/svg-fonts-simple + im_size: 384 + num_samples: -1 + transforms: false + image_processor: siglip_384 + test: + batch_size: 4 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: test + dataset_name: starvector/svg-fonts-simple + im_size: 384 + num_samples: 128 + transforms: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true \ No newline at end of file diff --git a/configs/models/starvector-8b/im2svg-fonts.yaml b/configs/models/starvector-8b/im2svg-fonts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c947c038daae410e3823f8df68a84885858cbb27 --- /dev/null +++ b/configs/models/starvector-8b/im2svg-fonts.yaml @@ -0,0 +1,94 @@ +project: + project: starvector-8b-im2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: im2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + 
backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: train + dataset_name: starvector/svg-fonts + im_size: 384 + num_samples: -1 + transforms: false + image_processor: siglip_384 + test: + batch_size: 4 + target: starvector.data.fontsvg.FontSVGDataset + params: + split: test + dataset_name: starvector/svg-fonts + im_size: 384 + num_samples: 128 + transforms: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true \ No newline at end of file diff --git a/configs/models/starvector-8b/im2svg-icons.yaml b/configs/models/starvector-8b/im2svg-icons.yaml new file mode 100644 index 0000000000000000000000000000000000000000..396193d591f17dfffab3978c871fa7c5c56f70dd --- /dev/null +++ b/configs/models/starvector-8b/im2svg-icons.yaml @@ -0,0 +1,94 @@ +project: + project: starvector-8b-im2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: im2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.iconsvg.SVGIconsDataset + params: + split: train + dataset_name: starvector/svg-icons + im_size: 384 + num_samples: -1 + transforms: false + image_processor: siglip_384 + test: + batch_size: 1 + target: starvector.data.iconsvg.SVGIconsDataset + params: + split: test + dataset_name: starvector/svg-icons + im_size: 384 + num_samples: 128 + transforms: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true \ No newline at end of file 
diff --git a/configs/models/starvector-8b/im2svg-stack.yaml b/configs/models/starvector-8b/im2svg-stack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5e336ce6a0c32937a274bea0700524ce42732e5 --- /dev/null +++ b/configs/models/starvector-8b/im2svg-stack.yaml @@ -0,0 +1,96 @@ +project: + project: starvector-8b-im2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: null + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: im2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 1 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: starvector/svg-stack + im_size: 384 + num_samples: -1 + transforms: false + select_dataset_name: false + image_processor: siglip_384 + test: + batch_size: 2 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: starvector/svg-stack + im_size: 384 + num_samples: 64 + transforms: false + select_dataset_name: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: true + Masked-L2: false + LPIPS: true + SSIM: true + FID: false + FID_clip: false + CLIPScore: false + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: true \ No newline at end of file diff --git a/configs/models/starvector-8b/text2svg-figr.yaml b/configs/models/starvector-8b/text2svg-figr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9783a5c676bca1bb9ef3d1dce522e944b25db9f3 --- /dev/null +++ b/configs/models/starvector-8b/text2svg-figr.yaml @@ -0,0 +1,96 @@ +project: + project: starvector-8b-text2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: text2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 10 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + 
adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 4 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: starvector/FIGR-SVG + im_size: 384 + num_samples: -1 + transforms: false + select_dataset_name: false + image_processor: siglip_384 + test: + batch_size: 4 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: starvector/FIGR-SVG + im_size: 384 + num_samples: 64 + transforms: false + select_dataset_name: false + image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: false + Masked-L2: false + LPIPS: false + SSIM: false + FID: false + FID_clip: false + CLIPScore: true + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: false \ No newline at end of file diff --git a/configs/models/starvector-8b/text2svg-stack.yaml b/configs/models/starvector-8b/text2svg-stack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aef26da9a006d247627ba2f1f3c35f86ba22065f --- /dev/null +++ b/configs/models/starvector-8b/text2svg-stack.yaml @@ -0,0 +1,96 @@ +project: + project: starvector-8b-text2svg + use_wandb: false + entity: joanrod + copy_code: false +model: + max_length: 16000 + model_name: starvector/starvector-8b-im2svg + starcoder_model_name: bigcode/starcoder2-7b + pretrained: true + image_encoder_type: siglip_384 + use_flash_attn: true + adapter_norm: layer_norm + init_type: normal + dropout: 0.1 + task: text2svg + transformer_layer_cls: Starcoder2DecoderLayer # fsdp specific + use_cache: false +training: + save_model_epochs: 1 + checkpointing_steps: 500 + checkpoints_total_limit: 5 + model_precision: bf16 + resume_from_checkpoint: false + continue_training: false + n_epochs: 4 + lr: 0.00001 + gradient_accumulation_steps: 4 + lr_scheduler: cosine + lr_warmup_steps: 10 + adam_beta1: 0.95 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-06 + adam_epsilon: 1e-08 + optimizer: adamw + use_gradient_checkpointing: true + train_image_encoder: true + train_LLM: true +fsdp: + enable: true # TODO: set this reasonably, i.e., false only if you want to use DDP or have PyTorch < 2.1 + cpu_offload: false + sharding_strategy: hsdp + backward_prefetch: BACKWARD_PRE + use_orig_params: true + sync_module_states: true + forward_prefetch: false + cpu_ram_efficient_loading: true +data: + num_workers: 16 + train: + batch_size: 4 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: train + dataset_name: starvector/text2svg-stack + im_size: 384 + num_samples: -1 + transforms: false + select_dataset_name: false + image_processor: siglip_384 + test: + batch_size: 4 + target: starvector.data.stacksvg.SVGStackDataset + params: + split: test + dataset_name: starvector/text2svg-stack + im_size: 384 + num_samples: 64 + transforms: false + select_dataset_name: false + 
image_processor: siglip_384 +generation: + max_length: 10000 + min_length: 10 + num_beams: 3 + temperature: 1.0 + num_captions: 1 + repetition_penalty: 1.0 + length_penalty: 0.5 + top_p: 0.95 + use_nucleus_sampling: true + im_size: 384 + dpi: 2 + scale: 300 +metrics: + L2: false + Masked-L2: false + LPIPS: false + SSIM: false + FID: false + FID_clip: false + CLIPScore: true + CountTokenLength: true + ratio_post_processed: false + ratio_non_compiling: false + DinoScore: false \ No newline at end of file diff --git a/controller.log b/controller.log new file mode 100644 index 0000000000000000000000000000000000000000..244095438d6b39bd7aad41faee9a156fc7968ad7 --- /dev/null +++ b/controller.log @@ -0,0 +1,31 @@ +2025-03-23 15:00:44 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue') +2025-03-23 15:00:44 | INFO | controller | Init controller +2025-03-23 15:00:45 | ERROR | stderr | INFO: Started server process [48368] +2025-03-23 15:00:45 | ERROR | stderr | INFO: Waiting for application startup. +2025-03-23 15:00:45 | ERROR | stderr | INFO: Application startup complete. +2025-03-23 15:00:45 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit) +2025-03-23 15:01:04 | INFO | controller | Register a new worker: http://localhost:40000 +2025-03-23 15:01:04 | INFO | controller | Register done: http://localhost:40000, {'model_names': ['/home/agent_h/data/starvector-1b-im2svg'], 'speed': 1, 'queue_length': 0} +2025-03-23 15:01:04 | INFO | stdout | INFO: 127.0.0.1:51486 - "POST /register_worker HTTP/1.1" 200 OK +2025-03-23 15:01:19 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:01:19 | INFO | stdout | INFO: 127.0.0.1:51523 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:01:34 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:01:34 | INFO | stdout | INFO: 127.0.0.1:51562 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:01:49 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:01:49 | INFO | stdout | INFO: 127.0.0.1:51607 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:02:04 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:02:04 | INFO | stdout | INFO: 127.0.0.1:51648 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:02:19 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:02:19 | INFO | stdout | INFO: 127.0.0.1:51683 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:02:34 | INFO | controller | Receive heart beat. http://localhost:40000 +2025-03-23 15:02:34 | INFO | stdout | INFO: 127.0.0.1:51729 - "POST /receive_heart_beat HTTP/1.1" 200 OK +2025-03-23 15:02:55 | ERROR | stderr | INFO: Shutting down +2025-03-23 15:02:55 | ERROR | stderr | INFO: Waiting for application shutdown. +2025-03-23 15:02:55 | ERROR | stderr | INFO: Application shutdown complete. +2025-03-23 15:02:55 | ERROR | stderr | INFO: Finished server process [48368] +2025-03-23 15:04:32 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue') +2025-03-23 15:04:32 | INFO | controller | Init controller +2025-03-23 15:04:32 | ERROR | stderr | INFO: Started server process [50695] +2025-03-23 15:04:32 | ERROR | stderr | INFO: Waiting for application startup. +2025-03-23 15:04:32 | ERROR | stderr | INFO: Application startup complete. 
+2025-03-23 15:04:32 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit) diff --git a/model_worker_ad9563.log b/model_worker_ad9563.log new file mode 100644 index 0000000000000000000000000000000000000000..c8415522730664a884cd52c906593a3cb2e34fe8 --- /dev/null +++ b/model_worker_ad9563.log @@ -0,0 +1,17 @@ +2025-03-23 15:01:04 | INFO | model_worker | args: Namespace(host='0.0.0.0', port=40000, worker_address='http://localhost:40000', controller_address='http://localhost:10000', model_name='/home/agent_h/data/starvector-1b-im2svg', multi_modal=False, limit_model_concurrency=5, stream_interval=1, no_register=False, openai_api_key='EMPTY', vllm_base_url='http://localhost:8000') +2025-03-23 15:01:04 | INFO | model_worker | Loading the model /home/agent_h/data/starvector-1b-im2svg on worker ad9563 ... +2025-03-23 15:01:04 | INFO | model_worker | Register to controller +2025-03-23 15:01:04 | ERROR | stderr | INFO: Started server process [48407] +2025-03-23 15:01:04 | ERROR | stderr | INFO: Waiting for application startup. +2025-03-23 15:01:04 | ERROR | stderr | INFO: Application startup complete. +2025-03-23 15:01:04 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:40000 (Press CTRL+C to quit) +2025-03-23 15:01:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:01:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:01:49 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:02:04 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:02:19 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:02:34 | INFO | model_worker | Send heart beat. Models: ['/home/agent_h/data/starvector-1b-im2svg']. Semaphore: None. global_counter: 0 +2025-03-23 15:02:45 | ERROR | stderr | INFO: Shutting down +2025-03-23 15:02:45 | ERROR | stderr | INFO: Waiting for application shutdown. +2025-03-23 15:02:45 | ERROR | stderr | INFO: Application shutdown complete. +2025-03-23 15:02:45 | ERROR | stderr | INFO: Finished server process [48407] diff --git a/models/modeling_starvector.py b/models/modeling_starvector.py new file mode 100644 index 0000000000000000000000000000000000000000..6a30af17bebc9284274f412539b32e2d48a792a6 --- /dev/null +++ b/models/modeling_starvector.py @@ -0,0 +1,224 @@ +""" +This module defines a self-contained StarVector model with support for remote code loading. 
+""" + +import os +import torch +import torch.nn as nn +from transformers import PreTrainedModel, PretrainedConfig +from typing import Optional, Union, List +from abc import ABC, abstractmethod + +# Import components - these will be included in the HF repo +from .starvector.image_encoder import ImageEncoder +from .starvector.adapter import Adapter + +# === Model Configuration === + +class StarVectorConfig(PretrainedConfig): + model_type = "starvector" + + def __init__( + self, + starcoder_model_name: str = "bigcode/starcoderbase-1b", + image_encoder_type: str = "clip", + adapter_norm: str = "layer_norm", + image_size: int = 224, + max_length: int = 8192, + max_length_train: int = 8192, + use_flash_attn: bool = True, + use_cache: bool = True, + num_attention_heads: int = 16, + num_hidden_layers: int = 24, + vocab_size: int = 49152, + hidden_size: int = 2048, + num_kv_heads: int = 4, + torch_dtype: str = "bfloat16", + **kwargs, + ): + # Initialize the parent config first + super().__init__(**kwargs) + self.starcoder_model_name = starcoder_model_name + self.image_encoder_type = image_encoder_type + self.adapter_norm = adapter_norm + self.image_size = image_size + self.max_length = max_length + self.max_length_train = max_length_train + self.use_flash_attn = use_flash_attn + self.use_cache = use_cache + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + self.torch_dtype = torch_dtype + +# === Base Model Classes === + +class StarVectorBase(nn.Module, ABC): + def __init__(self, config, **kwargs): + super().__init__() + self.task = kwargs.get('task', 'im2svg') + self.model_precision = kwargs.get('model_precision', config.torch_dtype) + + # Instantiate the SVG transformer using the abstract method. + self.svg_transformer = self._get_svg_transformer(config, **kwargs) + + if self.use_image_encoder(): + self.image_encoder = ImageEncoder(config, **kwargs) + self.image_projection = self.get_adapter(config, **kwargs).to(dtype=self.model_precision) + else: + self.query_length = 0 + + self.max_length = config.max_length_train - getattr(self, "query_length", 0) - 4 + self.train_image_encoder = kwargs.get('train_image_encoder', False) + self.train_LLM = kwargs.get('train_LLM', False) + + self._freeze_parameters(self.train_image_encoder, self.train_LLM) + + @abstractmethod + def _get_svg_transformer(self, config, **kwargs): + """Get SVG transformer model - implementation differs between versions""" + pass + + def _freeze_parameters(self, train_image_encoder, train_LLM): + if self.use_image_encoder(): + for _, param in self.image_encoder.named_parameters(): + param.requires_grad = train_image_encoder + for _, param in self.image_projection.named_parameters(): + param.requires_grad = train_image_encoder + for _, param in self.svg_transformer.transformer.named_parameters(): + param.requires_grad = train_LLM + + def use_image_encoder(self): + return self.task == 'im2svg' + + def get_adapter(self, config, **kwargs): + # Determine hidden size and query length based on the image encoder type. 
+ if config.image_encoder_type == 'clip': + hidden_size = self.image_encoder.num_features + self.query_length = 257 + elif config.image_encoder_type == 'vqgan': + hidden_size = 256 + self.query_length = 196 + else: + hidden_size = 256 # default fallback + self.query_length = 200 + llm_hidden_size = config.hidden_size # assuming the transformer hidden size + return Adapter(hidden_size, llm_hidden_size, adapter_norm=config.adapter_norm, query_length=self.query_length, dropout_prob=kwargs.get('dropout', 0.1)) + + def forward(self, batch): + # Simplified forward pass where we assume batch has an "image" key. + image = batch["image"] + if self.use_image_encoder(): + embedded_image = self.image_encoder(image) + conditioning_embeds = self.image_projection(embedded_image) + # For demo purposes, we generate dummy input embeddings (replace with your logic) + inputs_embeds = self.svg_transformer.transformer.wte( + torch.randint(0, self.svg_transformer.transformer.wte.num_embeddings, (image.size(0), self.max_length)) + ) + else: + inputs_embeds = self.svg_transformer.transformer.wte( + torch.randint(0, self.svg_transformer.transformer.wte.num_embeddings, (image.size(0), self.max_length)) + ) + return inputs_embeds # Dummy return + + def generate_im2svg(self, batch, **kwargs): + # Prepare generation inputs (dummy implementation) + image = batch["image"] + if self.use_image_encoder(): + embedded_image = self.image_encoder(image) + conditioning_embeds = self.image_projection(embedded_image) + else: + conditioning_embeds = torch.zeros((image.size(0), 10, 1), device=image.device) + generation_output = self.svg_transformer.transformer.generate(inputs_embeds=conditioning_embeds, max_length=kwargs.get('max_length', 30)) + raw_svg = self.svg_transformer.tokenizer.batch_decode(generation_output, skip_special_tokens=True) + return raw_svg + + @abstractmethod + def _get_embeddings(self, input_ids): + """Get embeddings from input ids - implementation differs between v1 and v2""" + pass + + @abstractmethod + def _get_svg_text(self, svg_list): + """Get SVG text with appropriate end tokens - implementation differs between v1 and v2""" + pass + +# V1 implementation: Delegates transformer creation to the external LLM file. +class StarVectorStarCoder(StarVectorBase): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + def _get_svg_transformer(self, config, **kwargs): + from starvector.model.llm.starcoder import StarCoderModel # V1: use StarCoderModel from external file + return StarCoderModel(config, **kwargs) + + def _get_embeddings(self, input_ids): + """V1-specific embedding method""" + # This follows the implementation in starvector/model/models/starvector_v1.py. + return self.svg_transformer.transformer.transformer.wte(input_ids) + + def _get_svg_text(self, svg_list): + """V1-specific SVG text preparation""" + return [t + self.svg_transformer.tokenizer.eos_token for t in svg_list] + +# V2 implementation: Delegates transformer creation to the external V2 LLM file. 
+class StarVectorStarCoder2(StarVectorBase): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + def _get_svg_transformer(self, config, **kwargs): + from starvector.model.llm.starcoder2 import StarCoderModel # V2: use external StarCoderModel from starcoder2.py + return StarCoderModel(config, **kwargs) + + def _get_embeddings(self, input_ids): + """V2-specific embedding method""" + return self.svg_transformer.transformer.model.embed_tokens(input_ids) + + def _get_svg_text(self, svg_list): + """V2-specific SVG text preparation""" + return [t + self.svg_transformer.svg_end_token + self.svg_transformer.tokenizer.eos_token for t in svg_list] + + def _get_im2svg_specific_kwargs(self, kwargs): + """V2-specific generation kwargs""" + return { + 'eos_token_id': self.svg_transformer.svg_end_token_id, + } + + def _get_text2svg_specific_kwargs(self, kwargs): + """V2-specific text2svg generation kwargs""" + return { + 'eos_token_id': self.svg_transformer.tokenizer.eos_token_id, + } + +# === Main Model Class for Hugging Face === + +class StarVectorForCausalLM(PreTrainedModel): + config_class = StarVectorConfig + _no_split_modules = [] + + def __init__(self, config, **kwargs): + super().__init__(config) + # Choose V2 if the model name indicates starcoder2; otherwise use V1. + if "starcoder2" in config.starcoder_model_name.lower(): + self.model = StarVectorStarCoder2(config=config, **kwargs) + else: + self.model = StarVectorStarCoder(config=config, **kwargs) + + def forward(self, batch): + return self.model(batch) + + def generate_im2svg(self, batch, **kwargs): + return self.model.generate_im2svg(batch, **kwargs) + + def generate_im2text(self, batch, **kwargs): + return self.model.generate_im2text(batch, **kwargs) + + def process_images(self, images): + return self.model.image_encoder.process_images(images) + +# === Registration for Autonomous Loading === + +StarVectorConfig.register_for_auto_class() +StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM") \ No newline at end of file diff --git a/models/starvector/adapter.py b/models/starvector/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb52be52063fa211dec255ed363f4a4afae1ee6 --- /dev/null +++ b/models/starvector/adapter.py @@ -0,0 +1,15 @@ +import torch.nn as nn + +class Adapter(nn.Module): + def __init__(self, in_features, out_features, adapter_norm="layer_norm", query_length=1, dropout_prob=0.1): + super().__init__() + self.fc = nn.Linear(in_features, out_features) + self.norm = nn.LayerNorm(out_features) if adapter_norm == "layer_norm" else None + self.dropout = nn.Dropout(dropout_prob) + self.query_length = query_length + + def forward(self, x): + out = self.fc(x) + if self.norm is not None: + out = self.norm(out) + return self.dropout(out) \ No newline at end of file diff --git a/models/starvector/clip_model.py b/models/starvector/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..088eb25e2c2a39f31086fb5cc8c2dc714894c0cc --- /dev/null +++ b/models/starvector/clip_model.py @@ -0,0 +1 @@ +# Copy your CLIP model implementation here \ No newline at end of file diff --git a/models/starvector/image_encoder.py b/models/starvector/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ac758bb271df62d4cabbcc467c9d945a38ebcc --- /dev/null +++ b/models/starvector/image_encoder.py @@ -0,0 +1,2 @@ +from .clip_model import convert_weights_to_precision, VisionTransformer, LayerNorm +# ... 
rest of ImageEncoder implementation ... \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6197c842e8639798f3129f0a11b75ea00da569a5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,70 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "starvector" +version = "1.0" +description = "Generating Scalable Vector Graphics Code from Images and Text" +readme = "README.md" +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "torch==2.5.1", + "torchvision==0.20.1", + "transformers==4.49.0", + "tokenizers==0.21.1", + "sentencepiece==0.2.0", + "accelerate", + "pydantic==2.10", + "markdown2[all]", + "numpy<2.0.0", + "scikit-learn==1.2.2", + "gradio==3.36.1", + "gradio_client==0.2.9", + "requests", + "httpx==0.24.0", + "uvicorn", + "fastapi", + "svgpathtools==1.6.1", + "seaborn==0.12.2", + "taming-transformers", + "lpips", + "cairosvg", + "beautifulsoup4", + "webcolors", + "tqdm", + "omegaconf", + "open-clip-torch", + "noise", + "datasets", + "scikit-image", + "fairscale", + "lxml", + "torch-fidelity", + "clip-openai", + "scipy==1.11.1", + "sentence-transformers", + "reportlab", + "svglib", + "Pillow", + "protobuf", + "openai", + +] + +[project.optional-dependencies] +train = ["deepspeed", "ninja", "wandb"] + +[project.urls] +"Homepage" = "https://starvector.github.io" +"Bug Tracker" = "https://github.com/joanrod/starvector/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] diff --git a/requirements.txt b/requirements.txt index 2984f2bea4ebe938fdd5839b08ace9ad3a51039e..ecf975e2fa63a1383916f0c755e0db3e0b3d62b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ --e ./star-vector-dev \ No newline at end of file +-e . 
\ No newline at end of file diff --git a/scripts/quickstart-hf.py b/scripts/quickstart-hf.py new file mode 100644 index 0000000000000000000000000000000000000000..f702ce40d5a3fa29e78e7ff32c4674d8cc946467 --- /dev/null +++ b/scripts/quickstart-hf.py @@ -0,0 +1,24 @@ +from PIL import Image +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor +from starvector.data.util import process_and_rasterize_svg +import torch + +# model_name = "starvector/starvector-1b-im2svg" +model_name = "starvector/starvector-8b-im2svg" + +starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True) +processor = starvector.model.processor +tokenizer = starvector.model.svg_transformer.tokenizer + +starvector.cuda() +starvector.eval() + +image_pil = Image.open('assets/examples/sample-18.png') + +image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda() +if not image.shape[0] == 1: + image = image.squeeze(0) +batch = {"image": image} + +raw_svg = starvector.generate_im2svg(batch, max_length=100)[0] +svg, raster_image = process_and_rasterize_svg(raw_svg) diff --git a/scripts/quickstart-vllm.py b/scripts/quickstart-vllm.py new file mode 100644 index 0000000000000000000000000000000000000000..edbf9f71559dccc3a2f0def48e381d2704dbaad7 --- /dev/null +++ b/scripts/quickstart-vllm.py @@ -0,0 +1,37 @@ +from PIL import Image +from vllm import LLM, SamplingParams + +model_name = "starvector/starvector-1b-im2svg" +# model_name = "starvector/starvector-8b-im2svg" + +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=7900, + n=1, + frequency_penalty=0.0, + repetition_penalty=1.0, + top_k=-1, + min_p=0.0, +) +llm = LLM(model=model_name, trust_remote_code=True, max_model_len=8192) + +prompt_start = "" +images = [Image.open('assets/examples/sample-18.png')] +model_inputs_vllm = [] +for i in range(len(images)): + model_inputs_vllm.append({ + "prompt": prompt_start, + "multi_modal_data": {"image": images[i]} + }) + +outputs = llm.generate(model_inputs_vllm, + sampling_params=sampling_params, + use_tqdm=False) + +completions = [] +for i in range(len(outputs)): + for j in range(len(outputs[i].outputs)): + completions.append(outputs[i].outputs[j].text) + +print(completions) diff --git a/scripts/quickstart.py b/scripts/quickstart.py new file mode 100644 index 0000000000000000000000000000000000000000..145d3fa58f2b83aefe8b58802652926fbb7a21b4 --- /dev/null +++ b/scripts/quickstart.py @@ -0,0 +1,20 @@ +from PIL import Image +from starvector.model.starvector_arch import StarVectorForCausalLM +from starvector.data.util import process_and_rasterize_svg +import torch + +model_name = "starvector/starvector-1b-im2svg" +# model_name = "starvector/starvector-8b-im2svg" + +starvector = StarVectorForCausalLM.from_pretrained(model_name, torch_dtype="auto") # add , torch_dtype="bfloat16" + +starvector.cuda() +starvector.eval() + +image_pil = Image.open("assets/examples/sample-18.png") +image_pil = image_pil.convert('RGB') +image = starvector.process_images([image_pil])[0].to(torch.float16).cuda() +batch = {"image": image} + +raw_svg = starvector.generate_im2svg(batch, max_length=4000, temperature=1.5, length_penalty=-1, repetition_penalty=3.1)[0] +svg, raster_image = process_and_rasterize_svg(raw_svg) diff --git a/scripts/train/train-starvector-1b-im2svg.sh b/scripts/train/train-starvector-1b-im2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..c43f9361ad24ca4ef15eb974f6aaecf65d3bb915 --- /dev/null +++ 
b/scripts/train/train-starvector-1b-im2svg.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export HF_HOME= +export HF_TOKEN= +export WANDB_API_KEY= +export OUTPUT_DIR= + +accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml \ + starvector/train/train.py \ + config=configs/models/starvector-1b/im2svg-stack.yaml diff --git a/scripts/train/train-starvector-1b-text2svg.sh b/scripts/train/train-starvector-1b-text2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..5cd13a24b44e3bd45081b1b903505f14a4c06eaf --- /dev/null +++ b/scripts/train/train-starvector-1b-text2svg.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export HF_HOME= +export HF_TOKEN= +export WANDB_API_KEY= +export OUTPUT_DIR= + +accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml \ + starvector/train/train.py \ + config=configs/models/starvector-1b/text2svg-stack.yaml diff --git a/scripts/train/train-starvector-8b-im2svg.sh b/scripts/train/train-starvector-8b-im2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..56eb60aee0ae73c99c64541787ca78f8ebee61c1 --- /dev/null +++ b/scripts/train/train-starvector-8b-im2svg.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +export HF_HOME= +export HF_TOKEN= +export WANDB_API_KEY= +export OUTPUT_DIR= + +torchrun \ + --nproc-per-node=2 \ + --nnodes=1 \ + starvector/train/train.py \ + config=configs/models/starvector-8b/im2svg-stack.yaml diff --git a/scripts/train/train-starvector-8b-text2svg.sh b/scripts/train/train-starvector-8b-text2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..4260fd199f294e63fbbc6d1327a47d0ff50280bd --- /dev/null +++ b/scripts/train/train-starvector-8b-text2svg.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +export HF_HOME= +export HF_TOKEN= +export WANDB_API_KEY= +export OUTPUT_DIR= + +torchrun \ + --nproc-per-node=2 \ + --nnodes=1 \ + starvector/train/train.py \ + config=configs/models/starvector-8b/text2svg-stack.yaml diff --git a/scripts/validation/validate-starvector-1b-im2svg.sh b/scripts/validation/validate-starvector-1b-im2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..83143df3f52baebbde823e6f589c506eb98af88b --- /dev/null +++ b/scripts/validation/validate-starvector-1b-im2svg.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-1b/im2svg.yaml \ +dataset.name svg-stack \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-1b/im2svg.yaml \ +dataset.name svg-emoji \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-1b/im2svg.yaml \ +dataset.name svg-fonts \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-1b/im2svg.yaml \ +dataset.name svg-diagrams \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-1b/im2svg.yaml \ +dataset.name svg-icons \ +model.generation_engine=hf + + + diff --git a/scripts/validation/validate-starvector-8b-im2svg.sh b/scripts/validation/validate-starvector-8b-im2svg.sh new file mode 100644 index 0000000000000000000000000000000000000000..ee1f8e91111ee153d2536dd38ec097581d2425d7 --- /dev/null +++ b/scripts/validation/validate-starvector-8b-im2svg.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-8b/im2svg.yaml \ +dataset.name svg-stack \ 
+model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-8b/im2svg.yaml \ +dataset.name svg-emoji \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-8b/im2svg.yaml \ +dataset.name svg-fonts \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-8b/im2svg.yaml \ +dataset.name svg-diagrams \ +model.generation_engine=hf + +python starvector/validation/run_validator.py \ +config=configs/generation/starvector-8b/im2svg.yaml \ +dataset.name svg-icons \ +model.generation_engine=hf + + diff --git a/starvector.egg-info/PKG-INFO b/starvector.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..41d3945519fa2b4e8e166e3a60965a969849d4f6 --- /dev/null +++ b/starvector.egg-info/PKG-INFO @@ -0,0 +1,495 @@ +Metadata-Version: 2.4 +Name: starvector +Version: 1.0 +Summary: Generating Scalable Vector Graphics Code from Images and Text +Project-URL: Homepage, https://starvector.github.io +Project-URL: Bug Tracker, https://github.com/joanrod/starvector/issues +Classifier: Programming Language :: Python :: 3 +Classifier: License :: OSI Approved :: Apache Software License +Requires-Python: >=3.11 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: torch==2.5.1 +Requires-Dist: torchvision==0.20.1 +Requires-Dist: transformers==4.49.0 +Requires-Dist: tokenizers==0.21.1 +Requires-Dist: sentencepiece==0.2.0 +Requires-Dist: accelerate +Requires-Dist: pydantic==2.10 +Requires-Dist: markdown2[all] +Requires-Dist: numpy<2.0.0 +Requires-Dist: scikit-learn==1.2.2 +Requires-Dist: gradio==3.36.1 +Requires-Dist: gradio_client==0.2.9 +Requires-Dist: requests +Requires-Dist: httpx==0.24.0 +Requires-Dist: uvicorn +Requires-Dist: fastapi +Requires-Dist: svgpathtools==1.6.1 +Requires-Dist: seaborn==0.12.2 +Requires-Dist: taming-transformers +Requires-Dist: lpips +Requires-Dist: cairosvg +Requires-Dist: beautifulsoup4 +Requires-Dist: webcolors +Requires-Dist: tqdm +Requires-Dist: omegaconf +Requires-Dist: open-clip-torch +Requires-Dist: noise +Requires-Dist: datasets +Requires-Dist: scikit-image +Requires-Dist: fairscale +Requires-Dist: lxml +Requires-Dist: torch-fidelity +Requires-Dist: clip-openai +Requires-Dist: scipy==1.11.1 +Requires-Dist: sentence-transformers +Requires-Dist: reportlab +Requires-Dist: svglib +Requires-Dist: Pillow +Requires-Dist: protobuf +Requires-Dist: openai +Provides-Extra: train +Requires-Dist: deepspeed; extra == "train" +Requires-Dist: ninja; extra == "train" +Requires-Dist: wandb; extra == "train" +Dynamic: license-file + +
+
+# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
+
+[arXiv] · [Website] · [HF Models: StarVector] · [HF Dataset: SVG-Stack] · [HF Dataset: SVG-Bench]
+
+## 🔥 News
+- March 2025: **StarVector Accepted at CVPR 2025**:
+  - StarVector has been accepted at CVPR 2025! [[Link](https://arxiv.org/abs/2312.11556)]
+  - Check out our website for more information [[Link](https://starvector.github.io/)]
+  - StarVector models are now available on the Hugging Face Model Hub! [[Link](https://huggingface.co/starvector/starvector-1b-im2svg)] [[Link](https://huggingface.co/starvector/starvector-8b-im2svg)]
+  - SVG-Bench and SVG-Stack datasets are now available on the Hugging Face Datasets Hub! [[Link](https://huggingface.co/datasets/starvector/svg-bench)] [[Link](https://huggingface.co/datasets/starvector/svg-stack)]
+
+## 🚀 Introduction
+StarVector is a multimodal vision-language model for Scalable Vector Graphics (SVG) generation. It can be used to perform image2SVG and text2SVG generation. We pose image generation as a code generation task, using the power of multimodal VLMs.
+
+*(figure: starvector)*
+
+> **Abstract**: Scalable Vector Graphics (SVGs) are vital for modern image rendering due to their scalability and versatility. Previous SVG generation methods have focused on curve-based vectorization, lacking semantic understanding, often producing artifacts, and struggling with SVG primitives beyond *path* curves. To address these issues, we introduce StarVector, a multimodal large language model for SVG generation. It performs image vectorization by understanding image semantics and using SVG primitives for compact, precise outputs. Unlike traditional methods, StarVector works directly in the SVG code space, leveraging visual understanding to apply accurate SVG primitives. To train StarVector, we create SVG-Stack, a diverse dataset of 2M samples that enables generalization across vectorization tasks and precise use of primitives like ellipses, polygons, and text. We address challenges in SVG evaluation, showing that pixel-based metrics like MSE fail to capture the unique qualities of vector graphics. We introduce SVG-Bench, a benchmark across 10 datasets and 3 tasks: Image-to-SVG, Text-to-SVG generation, and diagram generation. Using this setup, StarVector achieves state-of-the-art performance, producing more compact and semantically rich SVGs.
+
+### Multimodal Architecture
+
+StarVector uses a multimodal architecture to process images and text. When performing Image-to-SVG (or image vectorization), the image is projected into visual tokens, and SVG code is generated. When performing Text-to-SVG, the model only receives the text instruction (no image is provided), and a novel SVG is created. The LLM is based on StarCoder, which we leverage to transfer coding skills to SVG generation.
+
+*(figure: StarVector multimodal architecture)*
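+
+The snippet below is a minimal, illustrative sketch of this conditioning path, not the repository's actual forward pass: features from the image encoder are projected by an adapter into the decoder's embedding space and prepended to the embedded SVG-code tokens before generation. The `Adapter` here is a simplified version of the one in `models/starvector/adapter.py`; the 257 visual tokens and the 2048 hidden size follow the CLIP/1B defaults in `models/modeling_starvector.py`, while the encoder feature width and the remaining names are placeholders.
+
+```python
+import torch
+import torch.nn as nn
+
+class Adapter(nn.Module):
+    # Projects image-encoder features into the LLM embedding space
+    # (simplified from models/starvector/adapter.py).
+    def __init__(self, in_features, out_features, dropout_prob=0.1):
+        super().__init__()
+        self.fc = nn.Linear(in_features, out_features)
+        self.norm = nn.LayerNorm(out_features)
+        self.dropout = nn.Dropout(dropout_prob)
+
+    def forward(self, x):
+        return self.dropout(self.norm(self.fc(x)))
+
+# Illustrative sizes: 257 visual tokens, projected to the decoder hidden size
+# (2048 in StarVectorConfig for the 1B model); enc_dim is a placeholder.
+batch_size, num_visual_tokens, enc_dim, llm_dim = 2, 257, 1024, 2048
+adapter = Adapter(enc_dim, llm_dim)
+
+image_features = torch.randn(batch_size, num_visual_tokens, enc_dim)  # from the image encoder
+visual_tokens = adapter(image_features)                               # (2, 257, 2048)
+
+# Embedded SVG-code tokens (in the real model these come from the StarCoder
+# embedding table); for Text-to-SVG the visual tokens are simply omitted.
+text_embeds = torch.randn(batch_size, 32, llm_dim)
+inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
+print(inputs_embeds.shape)  # torch.Size([2, 289, 2048])
+```
+
+In the full model, `StarVectorForCausalLM` (see `models/modeling_starvector.py`) wires this projection into the StarCoder decoder, which then autoregressively emits the SVG code.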
+ +## 📖 Table of Contents +- [💿 Installation](#installation) +- [🏎️ Quick Start - Image2SVG Generation](#quick-start---image2svg-generation) +- [🎨 Models](#models) +- [📊 Datasets](#datasets---svg-bench) +- [🏋️‍♂️ Training](#training) +- [🏆 Evaluation on SVG-Bench](#validation-on-svg-benchmarks-svg-bench) +- [🧩 Demo](#starvector-demo) +- [📚 Citation](#citation) +- [📝 License](#license) + + +## Installation + +1. Clone this repository and navigate to star-vector folder +```bash +git clone https://github.com/joanrod/star-vector.git +cd star-vector +``` + +2. Install Package +```Shell +conda create -n starvector python=3.11.3 -y +conda activate starvector +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +3. Install additional packages for training +``` +pip install -e ".[train]" +``` + +### Upgrade to latest code base + +```Shell +git pull +pip install -e . +``` + +## Quick Start - Image2SVG Generation + +```Python +from PIL import Image +from starvector.model.starvector_arch import StarVectorForCausalLM +from starvector.data.util import process_and_rasterize_svg + +model_name = "starvector/starvector-8b-im2svg" + +starvector = StarVectorForCausalLM.from_pretrained(model_name) + +starvector.cuda() +starvector.eval() + +image_pil = Image.open('assets/examples/sample-0.png') +image = starvector.process_images([image_pil])[0].cuda() +batch = {"image": image} + +raw_svg = starvector.generate_im2svg(batch, max_length=1000)[0] +svg, raster_image = process_and_rasterize_svg(raw_svg) +``` + +### Use it from HuggingFace AutoModel + +```Python +from PIL import Image +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor +from starvector.data.util import process_and_rasterize_svg +import torch + +model_name = "starvector/starvector-8b-im2svg" + +starvector = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True) +processor = starvector.model.processor +tokenizer = starvector.model.svg_transformer.tokenizer + +starvector.cuda() +starvector.eval() + +image_pil = Image.open('assets/examples/sample-18.png') + +image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda() +if not image.shape[0] == 1: + image = image.squeeze(0) +batch = {"image": image} + +raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0] +svg, raster_image = process_and_rasterize_svg(raw_svg) +``` + + +## Models + +We provide [Hugging Face 🤗 model checkpoints](https://huggingface.co/collections/starvector/starvector-models-6783b22c7bd4b43d13cb5289) for image2SVG vectorization, for 💫 StarVector-8B and 💫 StarVector-1B. These are the results on SVG-Bench, using the DinoScore metric. 
+
+| Method | SVG-Stack | SVG-Fonts | SVG-Icons | SVG-Emoji | SVG-Diagrams |
+|---------------|-----------|-----------|-----------|-----------|--------------|
+| AutoTrace | 0.942 | 0.954 | 0.946 | 0.975 | 0.874 |
+| Potrace | 0.898 | 0.967 | 0.972 | 0.882 | 0.875 |
+| VTracer | 0.954 | 0.964 | 0.940 | 0.981 | 0.882 |
+| Im2Vec | 0.692 | 0.733 | 0.754 | 0.732 | - |
+| LIVE | 0.934 | 0.956 | 0.959 | 0.969 | 0.870 |
+| DiffVG | 0.810 | 0.821 | 0.952 | 0.814 | 0.822 |
+| GPT-4-V | 0.852 | 0.842 | 0.848 | 0.850 | - |
+| 💫 StarVector-1B (🤗 [Link](https://huggingface.co/starvector/starvector-1b-im2svg)) | 0.926 | 0.978 | 0.975 | 0.929 | 0.943 |
+| 💫 StarVector-8B (🤗 [Link](https://huggingface.co/starvector/starvector-8b-im2svg)) | **0.966** | **0.982** | **0.984** | **0.981** | **0.959** |
+
+*Note*: StarVector models will not work on natural images or illustrations, as they have not been trained on such images. They excel at vectorizing icons, logotypes, technical diagrams, graphs, and charts.
+
+## Datasets - SVG-Bench
+SVG-Bench is a benchmark for evaluating SVG generation models. It contains 10 datasets and 3 tasks: Image-to-SVG, Text-to-SVG, and Diagram-to-SVG.
+
+See our [Huggingface 🤗 Dataset Collection](https://huggingface.co/collections/starvector/starvector-svg-datasets-67811204a76475be4dd66d09)
+
+| Dataset | Train | Val | Test | Token Length | SVG Primitives | Annotation |
+|-----------------|--------|-------|------|------------------|----------------|----------------|
+| SVG-Stack (🤗 [Link](https://huggingface.co/datasets/starvector/svg-stack)) | 2.1M | 108k | 5.7k | 1,822 ± 1,808 | All | [Captions](https://huggingface.co/datasets/starvector/text2svg-stack) |
+| SVG-Stack_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-stack-simple)) | 601k | 30.1k | 1.5k | 2k ± 918 | Vector path | - |
+| SVG-Diagrams (🤗 [Link](https://huggingface.co/datasets/starvector/svg-diagrams)) | - | - | 472 | 3,486 ± 1,918 | All | - |
+| SVG-Fonts (🤗 [Link](https://huggingface.co/datasets/starvector/svg-fonts)) | 1.8M | 91.5k | 4.8k | 2,121 ± 1,868 | Vector path | Font letter |
+| SVG-Fonts_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-fonts-simple)) | 1.4M | 71.7k | 3.7k | 1,722 ± 723 | Vector path | Font letter |
+| SVG-Emoji (🤗 [Link](https://huggingface.co/datasets/starvector/svg-emoji)) | 8.7k | 667 | 668 | 2,551 ± 1,805 | All | - |
+| SVG-Emoji_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-emoji-simple)) | 580 | 57 | 96 | 2,448 ± 1,026 | Vector path | - |
+| SVG-Icons (🤗 [Link](https://huggingface.co/datasets/starvector/svg-icons)) | 80.4k | 6.2k | 2.4k | 2,449 ± 1,543 | Vector path | - |
+| SVG-Icons_sim (🤗 [Link](https://huggingface.co/datasets/starvector/svg-icons-simple)) | 80,435 | 2,836 | 1,277 | 2,005 ± 824 | Vector path | - |
+| SVG-FIGR (🤗 [Link](https://huggingface.co/datasets/starvector/FIGR-SVG)) | 270k | 27k | 3k | 5,342 ± 2,345 | Vector path | Class, Caption |
+
+> The table above summarizes the datasets used in our training and evaluation experiments. These datasets are included in SVG-Bench. The subscript _sim_ denotes the simplified version of a dataset, as required by some baselines.
+
+## Training
+
+### Confirm dependencies are installed
+
+```bash
+pip install -e ".[train]"
+```
+
+### Set environment variables
+We recommend setting the following environment variables:
+
+```bash
+ export HF_HOME=
+ export HF_TOKEN=
+ export WANDB_API_KEY=
+ export OUTPUT_DIR=
+```
+
+Then, change to the root of the repository:
+
+```Shell
+cd star-vector
+```
+
+### Image2SVG Pretraining (Stage 1)
+
+We use different training setups for StarVector-1B and StarVector-8B: StarVector-1B can be trained with DeepSpeed, while StarVector-8B requires FSDP.
+
+#### StarVector-1B Training
+
+Use the following command to train StarVector-1B on SVG-Stack for the Image2SVG vectorization task, using DeepSpeed and Accelerate:
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/im2svg-stack.yaml
+```
+
+#### StarVector-8B Training
+
+Use the following command to train StarVector-8B on SVG-Stack for the Image2SVG vectorization task, using FSDP and Accelerate. We provide the torchrun command to support multi-node, multi-GPU training.
+
+```bash
+# StarVector-8B
+torchrun \
+    --nproc-per-node=8 \
+    --nnodes=1 \
+    starvector/train/train.py \
+    config=configs/models/starvector-8b/im2svg-stack.yaml
+```
+
+### Finetuning StarVector (Stage 2)
+
+After pretraining StarVector on image vectorization, we finetune it on additional SVG tasks such as Text2SVG, as well as on the SVG-Bench datasets.
+
+#### Text2SVG Finetuning
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/text2svg-stack.yaml
+
+# StarVector-8B
+torchrun \
+    --nproc-per-node=8 \
+    --nnodes=1 \
+    starvector/train/train.py \
+    config=configs/models/starvector-8b/text2svg-stack.yaml
+```
+
+#### SVG-Bench Finetuning
+
+```bash
+# StarVector-1B
+accelerate launch --config_file configs/accelerate/deepspeed-8-gpu.yaml starvector/train/train.py config=configs/models/starvector-1b/im2svg-{fonts,icons,emoji}.yaml
+
+# StarVector-8B
+torchrun \
+    --nproc-per-node=8 \
+    --nnodes=1 \
+    starvector/train/train.py \
+    config=configs/models/starvector-8b/im2svg-{fonts,icons,emoji}.yaml
+```
+
+We also provide shell scripts in `scripts/train/*`.
+
+## Validation on SVG Benchmarks (⭐ SVG-Bench)
+
+We validate StarVector on the ⭐ SVG-Bench benchmark. We provide the `SVGValidator` class, which lets you run StarVector with **1) the HuggingFace generation backend** or **2) the vLLM backend**. The latter is substantially faster thanks to PagedAttention.
+
+### HuggingFace Generation Backend
+Let's start with the evaluation of StarVector-1B and StarVector-8B on SVG-Stack, using the HuggingFace generation backend (StarVectorHFAPIValidator). To override input arguments, you can add CLI args following the YAML file structure.
+
+```bash
+# StarVector-1B on SVG-Stack, using the HuggingFace backend
+python starvector/validation/validate.py \
+config=configs/generation/hf/starvector-1b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+
+# StarVector-8B on SVG-Stack, using the vanilla HuggingFace generation API
+python starvector/validation/validate.py \
+config=configs/generation/hf/starvector-8b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+```
+
+### vLLM Backend
+
+To use the vLLM backend (StarVectorVLLMAPIValidator), first install our StarVector fork of vLLM from [here](https://github.com/starvector/vllm).
+
+```bash
+git clone https://github.com/starvector/vllm.git
+cd vllm
+pip install -e .
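+
+# Optional sanity check: confirm the fork is the vllm that gets imported
+# (assumes the editable install above completed without errors).
+python -c "import vllm; print(vllm.__version__)"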
+```
+
+Then, launch validation using the vLLM config file (it uses StarVectorVLLMValidator):
+
+```bash
+# StarVector-1B
+python starvector/validation/validate.py \
+config=configs/generation/vllm/starvector-1b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+
+# StarVector-8B
+python starvector/validation/validate.py \
+config=configs/generation/vllm/starvector-8b/im2svg.yaml \
+dataset.name=starvector/svg-stack
+```
+
+#### Generate using Temperature Sweep
+Temperature sweep is an evaluation technique where we:
+1. Generate multiple SVG candidates using different temperature values (controlling randomness in generation)
+2. Evaluate each candidate using the DinoScore metric
+3. Select the best-performing SVG as the final output
+
+This approach improves result quality by exploring multiple generation possibilities, though it requires more computation time.
+
+```bash
+# StarVector-1B (vLLM)
+python starvector/validation/run_validator.py \
+config=configs/generation/vllm/starvector-1b/im2svg.yaml \
+dataset.name=svg-stack \
+generation_params.generation_sweep=True \
+generation_params.num_generations_different_temp=5 \
+generation_params.min_temperature=0.0 \
+generation_params.max_temperature=0.5
+
+# StarVector-8B (vLLM)
+python starvector/validation/run_validator.py \
+config=configs/generation/vllm/starvector-8b/im2svg.yaml \
+dataset.name=svg-stack \
+generation_params.generation_sweep=True \
+generation_params.num_generations_different_temp=10 \
+generation_params.min_temperature=0.0 \
+generation_params.max_temperature=0.5
+```
+
+We provide evaluation scripts in `scripts/eval/*`.
+
+## StarVector Demo
+
+The demo provides two options for converting images to SVG code:
+1. HuggingFace generation functionality
+2. vLLM (recommended) - offers faster generation speed
+
+### Option 1: HuggingFace Generation with Gradio Web UI
+
+We provide a Gradio web UI for you to play with our model.
+
+#### Launch a controller
+```Shell
+python -m starvector.serve.controller --host 0.0.0.0 --port 10000
+```
+
+#### Launch a Gradio web server
+```Shell
+python -m starvector.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7000
+```
+You have just launched the Gradio web interface. You can now open it with the URL printed on the screen. You may notice that the model list is empty; do not worry, we have not launched any model worker yet. The list updates automatically once you launch a model worker.
+
+#### Launch a model worker
+
+This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`.
+
+```Shell
+python -m starvector.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path joanrodai/starvector-1.4b
+```
+Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now refresh your Gradio web UI, and you will see the model you just launched in the model list.
+
+You can launch as many workers as you want and compare different model checkpoints in the same Gradio interface. Keep `--controller` the same, and change `--port` and `--worker` to a different port number for each worker.
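+
+For example, a second worker serving another checkpoint could be registered as follows (the port numbers are illustrative and any free ports will work; the checkpoint name is the 8B image-to-SVG model referenced above):
+
+```Shell
+# Hypothetical second worker: same controller, different --port/--worker, different checkpoint
+python -m starvector.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-path starvector/starvector-8b-im2svg
+```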
+
+### Option 2: Launch vLLM
+
+0. Remember to clone the starvector/vllm fork (it contains modifications needed for StarVector).
+
+```Shell
+git clone https://github.com/starvector/vllm.git
+cd vllm
+pip install -e .
+```
+
+1. Call this to launch the vLLM endpoint:
+
+```Shell
+vllm serve starvector/starvector-1b-im2svg --chat-template configs/chat-template.jinja --trust-remote-code --port 8000 --max-model-len 8192
+```
+
+2. Create the demo for vLLM:
+
+```Shell
+python -m starvector.serve.vllm_api_gradio.controller --host 0.0.0.0 --port 10000
+python -m starvector.serve.vllm_api_gradio.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 7000
+python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-name starvector/starvector-1b-im2svg --vllm-base-url http://localhost:8000
+```
+
+3. Add more models by serving them with vLLM and launching a new model worker:
+
+```Shell
+vllm serve starvector/starvector-8b-im2svg --chat-template configs/chat-template.jinja --trust-remote-code --port 8001 --max-model-len 16384
+
+python -m starvector.serve.vllm_api_gradio.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-name starvector/starvector-8b-im2svg --vllm-base-url http://localhost:8001
+```
+
+## Citation
+```
+@misc{rodriguez2024starvector,
+      title={StarVector: Generating Scalable Vector Graphics Code from Images and Text},
+      author={Juan A. Rodriguez and Abhay Puri and Shubham Agarwal and Issam H. Laradji and Pau Rodriguez and Sai Rajeswar and David Vazquez and Christopher Pal and Marco Pedersoli},
+      year={2024},
+      eprint={2312.11556},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2312.11556},
+}
+```
+
+## License
+This project is licensed under the Apache License, Version 2.0 - see the [LICENSE](LICENSE) file for details.
diff --git a/starvector.egg-info/SOURCES.txt b/starvector.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..0ec5ff1700fbce8f31a3c8197f4bc0c7d132e740 --- /dev/null +++ b/starvector.egg-info/SOURCES.txt @@ -0,0 +1,89 @@ +LICENSE +README.md +pyproject.toml +models/modeling_starvector.py +models/starvector/adapter.py +models/starvector/clip_model.py +models/starvector/image_encoder.py +starvector/__init__.py +starvector/adapter.py +starvector/clip_model.py +starvector/image_encoder.py +starvector/util.py +starvector.egg-info/PKG-INFO +starvector.egg-info/SOURCES.txt +starvector.egg-info/dependency_links.txt +starvector.egg-info/requires.txt +starvector.egg-info/top_level.txt +starvector/data/augmentation.py +starvector/data/base.py +starvector/data/dataset.py +starvector/data/emojisvg.py +starvector/data/figrsvg.py +starvector/data/fontsvg.py +starvector/data/iconsvg.py +starvector/data/stacksvg.py +starvector/data/util.py +starvector/metrics/base_metric.py +starvector/metrics/compute_LPIPS.py +starvector/metrics/compute_SSIM.py +starvector/metrics/compute_clip_score.py +starvector/metrics/compute_dino_score.py +starvector/metrics/compute_fid.py +starvector/metrics/compute_l2.py +starvector/metrics/count_token_length.py +starvector/metrics/inception.py +starvector/metrics/metrics.py +starvector/metrics/util.py +starvector/model/builder.py +starvector/model/starvector_arch.py +starvector/model/adapters/adapter.py +starvector/model/gpt_bigcode/__init__.py +starvector/model/gpt_bigcode/configuration_gpt_bigcode.py +starvector/model/gpt_bigcode/modeling_gpt_bigcode.py +starvector/model/image_encoder/clip_model.py +starvector/model/image_encoder/image_encoder.py +starvector/model/llm/starcoder.py +starvector/model/llm/starcoder2.py +starvector/model/models/starvector_base.py +starvector/model/models/starvector_v1.py +starvector/model/models/starvector_v2.py +starvector/serve/__init__.py +starvector/serve/constants.py +starvector/serve/controller.py +starvector/serve/conversation.py +starvector/serve/gradio_demo_with_updated_gradio.py +starvector/serve/gradio_web_server.py +starvector/serve/model_worker.py +starvector/serve/register_worker.py +starvector/serve/util.py +starvector/serve/vllm_api_gradio/controller.py +starvector/serve/vllm_api_gradio/gradio_vllm.py +starvector/serve/vllm_api_gradio/gradio_web_server.py +starvector/serve/vllm_api_gradio/model_worker.py +starvector/train/train.py +starvector/train/util.py +starvector/train/zero_to_fp32.py +starvector/validation/__init__.py +starvector/validation/starvector_hf_validator.py +starvector/validation/starvector_vllm_api_svg_validator.py +starvector/validation/starvector_vllm_svg_validator.py +starvector/validation/svg_validator_base.py +starvector/validation/validate.py +test-RL/broadcast_utils.py +test-RL/callbacks.py +test-RL/profile_callback.py +test-RL/quickstart-vllm.py +test-RL/reward.py +test-RL/reward_edge.py +test-RL/starvector_grpo_trainer.py +test-RL/test.py +test-RL/test_rl.py +test-RL/train.py +test-RL/visualize.py +test-RL/data/create_rl_dataset.py +test-RL/paper_figures/plot_metrics_step.py +test-RL/snow-launcher/launch-jobs.py +test-RL/snow-launcher/launch-overfit-job.py +test-RL/snow-launcher/launch_eval_jobs.py +test-RL/snow-launcher/re-launch-failed-jobs.py \ No newline at end of file diff --git a/starvector.egg-info/dependency_links.txt b/starvector.egg-info/dependency_links.txt new file mode 100644 index 
0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/starvector.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/starvector.egg-info/requires.txt b/starvector.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..26befc6149b31972b7c7b0f0c675e96837479100 --- /dev/null +++ b/starvector.egg-info/requires.txt @@ -0,0 +1,45 @@ +torch==2.5.1 +torchvision==0.20.1 +transformers==4.49.0 +tokenizers==0.21.1 +sentencepiece==0.2.0 +accelerate +pydantic==2.10 +markdown2[all] +numpy<2.0.0 +scikit-learn==1.2.2 +gradio==3.36.1 +gradio_client==0.2.9 +requests +httpx==0.24.0 +uvicorn +fastapi +svgpathtools==1.6.1 +seaborn==0.12.2 +taming-transformers +lpips +cairosvg +beautifulsoup4 +webcolors +tqdm +omegaconf +open-clip-torch +noise +datasets +scikit-image +fairscale +lxml +torch-fidelity +clip-openai +scipy==1.11.1 +sentence-transformers +reportlab +svglib +Pillow +protobuf +openai + +[train] +deepspeed +ninja +wandb diff --git a/starvector.egg-info/top_level.txt b/starvector.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..5f31865b3c7d15a453cba8d1e3356b6eb0fbcd80 --- /dev/null +++ b/starvector.egg-info/top_level.txt @@ -0,0 +1,4 @@ +configs +models +starvector +test-RL diff --git a/starvector/__init__.py b/starvector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/starvector/__pycache__/__init__.cpython-311.pyc b/starvector/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed80659e2ed2e9fbc90fedff7f83d90d0a3f8b3c Binary files /dev/null and b/starvector/__pycache__/__init__.cpython-311.pyc differ diff --git a/starvector/adapter.py b/starvector/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 
0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/clip_model.py b/starvector/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x 
= x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = 
[*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x diff --git a/starvector/data/augmentation.py b/starvector/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..e138e2fdbe257a1741a0c397c4c0339323951afb --- /dev/null +++ b/starvector/data/augmentation.py @@ -0,0 +1,250 @@ + +import numpy as np +from svgpathtools import ( + Path, Arc, CubicBezier, QuadraticBezier, + svgstr2paths) +import os +from noise import pnoise1 +import re +import matplotlib.colors as mcolors +from bs4 import BeautifulSoup +from starvector.data.util import rasterize_svg + +class SVGTransforms: + def __init__(self, transformations): + self.transformations = transformations + self.noise_std = self.transformations.get('noise_std', False) + self.noise_type = self.transformations.get('noise_type', False) + self.rotate = self.transformations.get('rotate', False) + self.shift_re = self.transformations.get('shift_re', False) + self.shift_im = self.transformations.get('shift_im', False) + self.scale = self.transformations.get('scale', False) + self.color_noise = self.transformations.get('color_noise', False) + self.p = self.transformations.get('p', 0.5) + self.color_change = self.transformations.get('color_change', False) + self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000']) + + def sample_transformations(self): + if self.rotate: + a, b = self.rotate['from'], self.rotate['to'] + rotation_angle = np.random.uniform(a, b) + self.rotation_angle = rotation_angle + + if self.shift_re or self.shift_im: + self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to']) + self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to']) + + if self.scale: + self.scale = np.random.uniform(self.scale['from'], self.scale['to']) + + if self.color_noise: + self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to']) + + + def paths2str(self, groupped_paths, svg_opening_tag=''): + + keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry'] + all_groups_srt = '' + for group, elements in groupped_paths.items(): + group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', []) + group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items()) + path_strings = [] + path_str = '' + for path, attributes in paths_and_attributes: + path_attr_str = '' + d_str = path.d() + + for key, value in attributes.items(): + if key not in keys_to_exclude: + path_attr_str += f' {key}="{value}"' + + path_strings.append(f'') + path_str = "\n".join(path_strings) + if 'no_group'in group: + group_str = path_str + else: + group_str = f'\n{path_str}\n\n' + all_groups_srt += group_str + svg = f'{svg_opening_tag}\n{all_groups_srt}' + return svg + + def add_noise(self, seg): + noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to']) + if self.noise_type == 'gaussian': + noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \ + 1j * np.random.normal(loc=0.0, scale=noise_scale) + elif self.noise_type == 'perlin': + noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale + + if isinstance(seg, CubicBezier): + seg.control1 = seg.control1 + noise_sample + seg.control2 = seg.control2 + noise_sample + elif isinstance(seg, QuadraticBezier): + seg.control = seg.control + 
noise_sample + elif isinstance(seg, Arc): + seg.radius = seg.radius + noise_sample + + + return seg + + def do_rotate(self, path, viewbox_width, viewbox_height): + if self.rotate: + new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2)) + return new_path + else: + return path + + def do_shift(self, path): + if self.shift_re or self.shift_im: + return path.translated(complex(self.shift_real, self.shift_imag)) + else: + return path + + def do_scale(self, path): + if self.scale: + return path.scaled(self.scale) + else: + return path + + def add_color_noise(self, source_color): + # Convert color to RGB + if source_color.startswith("#"): + base_color = mcolors.hex2color(source_color) + else: + base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF')) + + # Add noise to each RGB component + noise = np.random.normal(0, self.color_noise_std, 3) + noisy_color = np.clip(np.array(base_color) + noise, 0, 1) + + # Convert the RGB color back to hex + hex_color = mcolors.rgb2hex(noisy_color) + + return hex_color + + def do_color_change(self, attr): + if 'fill' in attr: + if self.color_noise or self.color_change: + fill_value = attr['fill'] + if fill_value == 'none': + new_fill_value = 'none' + else: + if self.color_noise: + new_fill_value = self.add_color_noise(fill_value) + elif self.color_change: + new_fill_value = np.random.choice(self.colors) + attr['fill'] = new_fill_value + return attr + + def clean_attributes(self, attr): + attr_out = {} + if 'fill' in attr: + attr_out = attr + elif 'style' in attr: + fill_values = re.findall('fill:[^;]+', attr['style']) + if fill_values: + fill_value = fill_values[0].replace('fill:', '').strip() + attr_out['fill'] = fill_value + else: + attr_out = attr + else: + attr_out = attr + + return attr_out + + def get_viewbox_size(self, svg): + # Try to extract viewBox attribute + match = re.search(r'viewBox="([^"]+)"', svg) + if match: + viewbox = match.group(1) + else: + # If viewBox is not found, try to extract width and height attributes + match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg) + if match: + width, height = match.groups() + viewbox = f"0 0 {width} {height}" + else: + viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found + + viewbox = [float(x) for x in viewbox.split()] + viewbox_width, viewbox_height = viewbox[2], viewbox[3] + return viewbox_width, viewbox_height + + def augment(self, svg): + if os.path.isfile(svg): + # open svg file + with open(svg, 'r') as f: + svg = f.read() + + # Sample transformations for this sample + self.sample_transformations() + + + # Parse the SVG content + soup = BeautifulSoup(svg, 'xml') + + # Get opening tag + svg_opening_tag = re.findall(']+>', svg)[0] + + viewbox_width, viewbox_height = self.get_viewbox_size(svg) + + # Get all svg parents + groups = soup.findAll() + + # Create the groups of paths based on their original tag + grouped_paths = {} + for i, g in enumerate(groups): + if g.name == 'g': + group_id = group_id = g.get('id') if g.get('id') else f'none_{i}' + group_attrs = g.attrs + + elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs': + continue + + else: + group_id = f'no_group_{i}' + group_attrs = {} + + group_svg_string = f'{svg_opening_tag}{str(g)}' + try: + paths, attributes = svgstr2paths(group_svg_string) + except: + return svg, rasterize_svg(svg) + if not paths: + continue + + paths_and_attributes = [] + + # Rotation, shift, scale, noise addition + new_paths = [] + new_attributes = [] + for path, 
attribute in zip(paths, attributes): + attr = self.clean_attributes(attribute) + + new_path = self.do_rotate(path, viewbox_width, viewbox_height) + new_path = self.do_shift(new_path) + new_path = self.do_scale(new_path) + + if self.noise_std: + # Add noise to path to deform svg + noisy_path = [] + for seg in new_path: + noisy_seg = self.add_noise(seg) + noisy_path.append(noisy_seg) + new_paths.append(Path(*noisy_path)) + else: + new_paths.append(new_path) + + # Color change + attr = self.do_color_change(attr) + paths_and_attributes.append((new_path, attr)) + + grouped_paths[group_id] = { + 'paths': paths_and_attributes, + 'attrs': group_attrs + } + + svg = self.paths2str(grouped_paths, svg_opening_tag) + image = rasterize_svg(svg) + + return svg, image diff --git a/starvector/data/base.py b/starvector/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..33fee34512badbba8aef048a20651c8418e5a0c1 --- /dev/null +++ b/starvector/data/base.py @@ -0,0 +1,71 @@ +from torch.utils.data import Dataset +from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg +from starvector.util import instantiate_from_config +import numpy as np +from datasets import load_dataset + +class SVGDatasetBase(Dataset): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + self.split = split + self.im_size = im_size + + transforms = kwargs.get('transforms', False) + if transforms: + self.transforms = instantiate_from_config(transforms) + self.p = self.transforms.p + else: + self.transforms = None + self.p = 0.0 + + normalization = kwargs.get('normalize', False) + if normalization: + mean = tuple(normalization.get('mean', None)) + std = tuple(normalization.get('std', None)) + else: + mean = None + std = None + + self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) + self.data = load_dataset(dataset_name, split=split) + + print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") + + def __len__(self): + return len(self.data_json) + + def get_svg_and_image(self, svg_str, sample_id): + do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) + svg, image = None, None + + # Try to augment the image if conditions are met + if self.transforms is not None and do_augment: + try: + svg, image = self.transforms.augment(svg_str) + except Exception as e: + print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") + + # If augmentation failed or wasn't attempted, try to rasterize the SVG + if svg is None or image is None: + try: + svg, image = svg_str, rasterize_svg(svg_str, self.im_size) + except Exception as e: + print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") + svg = use_placeholder() + image = rasterize_svg(svg, self.im_size) + + # If the image is completely white, use a placeholder image + if np.array(image).mean() == 255.0: + print(f"Image is full white, using placeholder image for {sample_id}") + svg = use_placeholder() + image = rasterize_svg(svg) + + # Process the image + if 'siglip' in self.image_processor: + image = self.processor(image).pixel_values[0] + else: + image = self.processor(image) + + return svg, image + + def __getitem__(self, idx): + raise NotImplementedError("This method should be implemented by subclasses") diff --git a/starvector/data/dataset.py b/starvector/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f58c5c1d937856f58958321cee26d09439ef68e7 --- /dev/null +++ 
b/starvector/data/dataset.py @@ -0,0 +1,42 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) + + self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + select_dataset_name = kwargs.get('select_dataset_name', False) + + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/emojisvg.py b/starvector/data/emojisvg.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf867d1ed188765836f8ebac1d7df714f041834 --- /dev/null +++ b/starvector/data/emojisvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase + + +class EmojiSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/figrsvg.py b/starvector/data/figrsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..32280a4d0383a06863ceb036e0287974f5d8d36a --- /dev/null +++ b/starvector/data/figrsvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FigrSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Id'] + svg, image = 
self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/fontsvg.py b/starvector/data/fontsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..1c90da9906f9af03b856fb751bdf879e1bad4401 --- /dev/null +++ b/starvector/data/fontsvg.py @@ -0,0 +1,28 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FontSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/iconsvg.py b/starvector/data/iconsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..45881997dd4f308e75035b5b7db0d4ad8a37ab03 --- /dev/null +++ b/starvector/data/iconsvg.py @@ -0,0 +1,38 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGIconsDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/stacksvg.py b/starvector/data/stacksvg.py new file mode 100644 index 0000000000000000000000000000000000000000..23d3ea76b847a437d84855b0aa9f1721dfd82c42 --- /dev/null +++ b/starvector/data/stacksvg.py @@ -0,0 +1,59 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +import random +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +text2svg_captions = [ + "Draw an SVG of ", + "Draw an SVG image of ", + "Draw an SVG picture of ", + "Generate an SVG of ", + "Create an SVG of ", + "Design an SVG of ", + "Make an SVG of ", +] + +class SVGStackDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) 
+ self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + + # Text2SVG specific + self.random_caption = kwargs.get('random_caption', True) + select_dataset_name = kwargs.get('select_dataset_name', False) + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if self.image_processor and 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + + # Randomly choose between 'caption_blip' and 'caption_llava' + caption_column = random.choice(['caption_blip2', 'caption_llava']) + caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption, + } diff --git a/starvector/data/util.py b/starvector/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..15254b104b1adde3c62a712628e59a8cf5f34b65 --- /dev/null +++ b/starvector/data/util.py @@ -0,0 +1,373 @@ +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import numpy as np +import matplotlib.pyplot as plt +from bs4 import BeautifulSoup +import re +from svgpathtools import svgstr2paths +import numpy as np +from PIL import Image +import cairosvg +from io import BytesIO +import numpy as np +import textwrap +import os +import base64 +import io + + + +CIRCLE_SVG = "" +VOID_SVF = "" + +def load_transforms(): + transforms = { + 'train': None, + 'eval': None + } + return transforms + +class ImageBaseProcessor(): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean=mean, std=std) + +class ImageTrainProcessor(ImageBaseProcessor): + def __init__(self, mean=None, std=None, size=224, **kwargs): + super().__init__(mean, std) + + self.size = size + + self.transform = transforms.Compose([ + transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + self.normalize + ]) + + def __call__(self, item): + return self.transform(item) + +def encode_image_base64(pil_image): + if pil_image.mode == 'RGBA': + pil_image = pil_image.convert('RGB') # Convert RGBA to RGB + buffered = io.BytesIO() + pil_image.save(buffered, format="JPEG") + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + return base64_image + +# -------------- Generation utils -------------- +def is_valid_svg(svg_text): + try: + svgstr2paths(svg_text) + return True + except Exception as e: + print(f"Invalid SVG: {str(e)}") + return False + +def clean_svg(svg_text, output_width=None, output_height=None): + soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse 
as xml + svg_bs4 = soup.prettify() # Prettify to get a string + + # Store the original signal handler + import signal + original_handler = signal.getsignal(signal.SIGALRM) + + try: + # Set a timeout to prevent hanging + def timeout_handler(signum, frame): + raise TimeoutError("SVG processing timed out") + + # Set timeout + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(5) + + # Try direct conversion without BeautifulSoup + svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode() + + except TimeoutError: + print("SVG conversion timed out, using fallback method") + svg_cairo = """""" + finally: + # Always cancel the alarm and restore original handler, regardless of success or failure + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("]*\/>" + all_tags = re.findall(all_tags_pattern, svg_content) + self_closing_matches = re.findall(self_closing_pattern, svg_content) + self_closing_tags = [] + + for match in self_closing_matches: + tag = re.search(all_tags_pattern, match) + if tag: + self_closing_tags.append(tag.group(1)) + unclosed_tags = [] + + for tag in all_tags: + if all_tags.count(tag) > self_closing_tags.count(tag) + svg_content.count(''): + unclosed_tags.append(tag) + unclosed_tags = list(dict.fromkeys(unclosed_tags)) + + return unclosed_tags + + +# -------------- Plotting utils -------------- +def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side(image1, image2, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename): + # Create a plot with the original image and different temperature results + num_temps = len(samples_temp) + fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]}) + + # Plot the original image + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[0, 0].imshow(gt_image) + axes[0, 0].set_title('Original') + axes[0, 0].axis('off') + axes[1, 0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', 
fontsize=16) + axes[1, 0].axis('off') + + # Plot the generated images for different temperatures and metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[0, idx + 1].imshow(gen_image) + axes[0, idx + 1].set_title(f'Temp {temp}') + axes[0, idx + 1].axis('off') + axes[1, idx + 1].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}', + horizontalalignment='center', verticalalignment='center', fontsize=12) + axes[1, idx + 1].axis('off') + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + +def plot_images_and_prompt(prompt, svg_raster, gt_svg_raster, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename): + # Calculate the number of temperature variations + num_temps = len(samples_temp) + + # Create a plot with text, the original image, and different temperature results + fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + + # Plot the GT (ground truth) image in the second cell + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[1].imshow(gt_image) + axes[1].set_title('GT Image') + axes[1].axis('off') + + # Plot the generated images for different temperatures 
and display metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[idx + 2].imshow(gen_image) + axes[idx + 2].set_title(f'Temp {temp}') + axes[idx + 2].axis('off') + clip_score = metrics[temp]["clip_score"] + axes[idx + 2].text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=axes[idx + 2].transAxes) + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + + return comparison_path + + +def plot_image_tensor(image): + import numpy as np + from PIL import Image + tensor = image[0].cpu().float() + tensor = tensor.permute(1, 2, 0) + array = (tensor.numpy() * 255).astype(np.uint8) + im = Image.fromarray(array) + im.save("tmp/output_image.jpg") + + +def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'): + # Calculate the number of rows required for the grid + num_images = len(images) + num_rows = (num_images + num_cols - 1) // num_cols + + # Create a new figure + fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8)) + + # Loop through the image files and plot them + for i, image in enumerate(images): + row = i // num_cols + col = i % num_cols + + # Open and display the image using Pillow + if type(image) == str: + img = Image.open(image) + else: + img = image + axes[row, col].imshow(img) + # axes[row, col].set_title(os.path.basename(image_file)) + axes[row, col].axis('off') + + # Remove empty subplots + for i in range(num_images, num_rows * num_cols): + row = i // num_cols + col = i % num_cols + fig.delaxes(axes[row, col]) + + # Adjust spacing between subplots + plt.tight_layout() + + # save image + plt.savefig(out_path, dpi=300) + image = Image.open(out_path) + plt.close(fig) + + return image \ No newline at end of file diff --git a/starvector/image_encoder.py b/starvector/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3f85a4cad987d91cf47285c9dd3099e6c7d6b24d --- /dev/null +++ b/starvector/image_encoder.py @@ -0,0 +1,119 @@ +import os +import torch +import torch.nn as nn +import os +from omegaconf import OmegaConf +from starvector.model.image_encoder.clip_model import convert_weights_to_precision +from starvector.data.util import ImageTrainProcessor + +class ImageEncoder(nn.Module): + def __init__(self, config, **kwargs): + super(ImageEncoder, self).__init__() + + image_size = config.image_size + torch_dtype = kwargs.get('model_precision', config.torch_dtype) + self.image_encoder_type = config.image_encoder_type + if self.image_encoder_type == 'clip': + self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) + convert_weights_to_precision(self, torch_dtype) + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'vqgan': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'convnext': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif 'siglip' in self.image_encoder_type: + if self.image_encoder_type == 'siglip_512': + model_name = "google/siglip-base-patch16-512" + elif self.image_encoder_type == 'siglip_384': + 
model_name = "google/siglip-large-patch16-384" + elif self.image_encoder_type == 'siglip_256': + model_name = "google/siglip-base-patch16-256" + + from transformers import AutoProcessor, AutoModel + + self.visual_encoder = AutoModel.from_pretrained( + model_name, torch_dtype = torch_dtype + ).vision_model + + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype = torch_dtype + ) + + def build_clip_encoder(self, image_size): + from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm + visual_encoder = VisionTransformer( + input_resolution=image_size, + patch_size=14, + width=1024, + layers=23, + heads=16, + use_grad_checkpointing=False) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def build_vqgan_encoder(self): + from taming.modules.diffusionmodules.model import Encoder + VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md + vqgan_chkp_path = VQGAN_CHECKPOINT + files_in_directory = os.listdir(vqgan_chkp_path + '/configs') + vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] + vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) + visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) + + # Load checkpoint weights + checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] + + # Create a new state_dict with modified keys + new_state_dict = {} + for key, value in checkpoint.items(): + if key.startswith('encoder.'): + new_key = key[len('encoder.'):] + new_state_dict[new_key] = value + + # Load weights + visual_encoder.load_state_dict(new_state_dict) + return visual_encoder + + def build_convnext_encoder(self): + import open_clip + model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') + return model.visual + + def forward(self, image): + if self.image_encoder_type == 'clip': + embeds = self.visual_encoder(image) + out = self.ln_vision(embeds) + elif self.image_encoder_type == 'open-clip': + out = self.visual_encoder(image)[1] + out = self.ln_vision(out) + elif self.image_encoder_type == 'vqgan': + out = self.visual_encoder(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif self.image_encoder_type == 'convnext': + out = self.visual_encoder.trunk.forward_features(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif 'siglip' in self.image_encoder_type: + out = self.visual_encoder(image)["last_hidden_state"] + return out + + def process_images(self, images): + if self.image_encoder_type == 'clip': + res = [] + for image in images: + res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W + return res + else: + return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) + \ No newline at end of file diff --git a/starvector/metrics/base_metric.py b/starvector/metrics/base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d07d8472bc70220d8bc2188d41dcb20d8e92a1 --- /dev/null +++ b/starvector/metrics/base_metric.py @@ -0,0 +1,51 @@ +from starvector.metrics.util import AverageMeter +from tqdm import tqdm +import math + +class BaseMetric: + def __init__(self): + self.meter = AverageMeter() + + def reset(self): + self.meter.reset() + + def calculate_score(self, batch, update=True): + """ + 
Batch: {"gt_im": [PIL Image], "gen_im": [Image]} + """ + values = [] + batch_size = len(next(iter(batch.values()))) + for index in tqdm(range(batch_size)): + kwargs = {} + for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]: + if key in batch: + kwargs[key] = batch[key][index] + try: + measure = self.metric(**kwargs) + except Exception as e: + print("Error calculating metric: {}".format(e)) + continue + if math.isnan(measure): + continue + values.append(measure) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + score = sum(values) / len(values) + if update: + self.meter.update(score, len(values)) + return self.meter.avg, values + else: + return score, values + + def metric(self, **kwargs): + """ + This method should be overridden by subclasses to provide the specific metric computation. + """ + raise NotImplementedError("The metric method must be implemented by subclasses.") + + def get_average_score(self): + return self.meter.avg + diff --git a/starvector/metrics/compute_LPIPS.py b/starvector/metrics/compute_LPIPS.py new file mode 100644 index 0000000000000000000000000000000000000000..b30c42cfdf21febf2b30a83dd51d690423294321 --- /dev/null +++ b/starvector/metrics/compute_LPIPS.py @@ -0,0 +1,56 @@ +from torchvision.transforms import ToTensor, Normalize +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +import lpips +from tqdm import tqdm + + +class LPIPSDistanceCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model = lpips.LPIPS(net='vgg').to(device) + self.metric = self.LPIPS + self.to_tensor = ToTensor() + self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.device = device + + def LPIPS(self, tensor_image1, tensor_image2): + tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device) + return self.model(tensor_image1, tensor_image2) + + def to_tensor_transform(self, pil_img): + return self.normalize(self.to_tensor(pil_img)) + + def collate_fn(self, batch): + gt_imgs, gen_imgs = zip(*batch) + tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs]) + tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs]) + return tensor_gt_imgs, tensor_gen_imgs + + def calculate_score(self, batch, batch_size=8, update=True): + gt_images = batch['gt_im'] + gen_images = batch['gen_im'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False) + + values = [] + for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader): + # Compute LPIPS + lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch) + values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist()) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + avg_score = sum(values) / len(values) + if update: + self.meter.update(avg_score, len(values)) + return self.meter.avg, values + else: + return avg_score, values + \ No newline at end of file diff --git a/starvector/metrics/compute_SSIM.py b/starvector/metrics/compute_SSIM.py new file mode 100644 index 0000000000000000000000000000000000000000..e0dfb75435d78261197e12276cb53ca540ea7687 --- 
/dev/null +++ b/starvector/metrics/compute_SSIM.py @@ -0,0 +1,35 @@ +from starvector.metrics.base_metric import BaseMetric +from skimage.metrics import structural_similarity as ssim +import numpy as np + +class SSIMDistanceCalculator(BaseMetric): + def __init__(self, config=None): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.compute_SSIM + + def compute_SSIM(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy + channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1 + sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter + + # Convert images to numpy arrays if they aren't already + img1_np = np.array(image1) + img2_np = np.array(image2) + + # Check if images are grayscale or RGB + if len(img1_np.shape) == 3 and img1_np.shape[2] == 3: + # Compute SSIM for RGB images + score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True) + else: + # Convert to grayscale if not already + if len(img1_np.shape) == 3: + img1_np = np.mean(img1_np, axis=2) + img2_np = np.mean(img2_np, axis=2) + + score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True) + + return score \ No newline at end of file diff --git a/starvector/metrics/compute_clip_score.py b/starvector/metrics/compute_clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..186fb2f85bef817a53f5e00a6a82a310451fe15b --- /dev/null +++ b/starvector/metrics/compute_clip_score.py @@ -0,0 +1,55 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch +from torchmetrics.multimodal.clip_score import CLIPScore +from torch.utils.data import DataLoader +from tqdm import tqdm +import torchvision.transforms as transforms +from torchmetrics.functional.multimodal.clip_score import _clip_score_update + +class CLIPScoreCalculator(BaseMetric): + def __init__(self): + super().__init__() + self.class_name = self.__class__.__name__ + self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32") + self.clip_score.to('cuda') + + def CLIP_Score(self, images, captions): + all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor) + return all_scores + + def collate_fn(self, batch): + gen_imgs, captions = zip(*batch) + tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs] + return tensor_gen_imgs, captions + + def calculate_score(self, batch, batch_size=512, update=True): + gen_images = batch['gen_im'] + captions = batch['caption'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + + all_scores = [] + for batch_eval in tqdm(data_loader): + images, captions = batch_eval + images = [img.to('cuda', non_blocking=True) * 255 for img in images] + list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist() + all_scores.extend(list_scores) + + if not all_scores: + print("No valid scores found for metric calculation.") + return float("nan"), [] + + avg_score = sum(all_scores) / len(all_scores) + if update: + self.meter.update(avg_score, len(all_scores)) + return self.meter.avg, all_scores + else: + return avg_score, all_scores + +if __name__ == 
'__main__': + import multiprocessing + multiprocessing.set_start_method('spawn') + # Rest of your code... \ No newline at end of file diff --git a/starvector/metrics/compute_dino_score.py b/starvector/metrics/compute_dino_score.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8364c4d5fd6a69caed545e95128ef34f9c56a --- /dev/null +++ b/starvector/metrics/compute_dino_score.py @@ -0,0 +1,55 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from transformers import AutoModel, AutoImageProcessor +from PIL import Image +import torch.nn as nn + +class DINOScoreCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model, self.processor = self.get_DINOv2_model("base") + self.model = self.model.to(device) + self.device = device + + self.metric = self.calculate_DINOv2_similarity_score + + def get_DINOv2_model(self, model_size): + if model_size == "small": + model_size = "facebook/dinov2-small" + elif model_size == "base": + model_size = "facebook/dinov2-base" + elif model_size == "large": + model_size = "facebook/dinov2-large" + else: + raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}") + return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size) + + def process_input(self, image, processor): + if isinstance(image, str): + image = Image.open(image) + if isinstance(image, Image.Image): + with torch.no_grad(): + inputs = processor(images=image, return_tensors="pt").to(self.device) + outputs = self.model(**inputs) + features = outputs.last_hidden_state.mean(dim=1) + elif isinstance(image, torch.Tensor): + features = image.unsqueeze(0) if image.dim() == 1 else image + else: + raise ValueError("Input must be a file path, PIL Image, or tensor of features") + return features + + def calculate_DINOv2_similarity_score(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + features1 = self.process_input(image1, self.processor) + features2 = self.process_input(image2, self.processor) + + cos = nn.CosineSimilarity(dim=1) + sim = cos(features1, features2).item() + sim = (sim + 1) / 2 + + return sim diff --git a/starvector/metrics/compute_fid.py b/starvector/metrics/compute_fid.py new file mode 100644 index 0000000000000000000000000000000000000000..413fca4a4c14e66a30b4aafee21220d4e02a41c0 --- /dev/null +++ b/starvector/metrics/compute_fid.py @@ -0,0 +1,145 @@ +# Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html +# from torchmetrics.image.fid import FrechetInceptionDistance +from PIL import Image +from starvector.metrics.base_metric import BaseMetric +import torch +from torchvision import transforms +import clip +from torch.nn.functional import adaptive_avg_pool2d +from starvector.metrics.inception import InceptionV3 +import numpy as np +from tqdm import tqdm +from scipy import linalg +import torchvision.transforms as TF + +class FIDCalculator(BaseMetric): + def __init__(self, model_name = 'InceptionV3',): + self.class_name = self.__class__.__name__ + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model_name = model_name + if self.model_name == 'ViT-B/32': + self.dims = 512 + model, preprocess = clip.load('ViT-B/32') + + elif self.model_name == 'InceptionV3': + self.dims = 2048 + block_idx = 
InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] + model = InceptionV3([block_idx]).to(self.device) + preprocess = TF.Compose([TF.ToTensor()]) + + self.model = model.cuda() + self.preprocess = preprocess + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + def get_activations(self, images): + dataset = ImageDataset(images, self.preprocess) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4) + pred_arr = np.empty((len(images), self.dims)) + start_idx = 0 + for batch in tqdm(dataloader): + batch = batch.to(self.device) + + with torch.no_grad(): + if self.model_name == 'ViT-B/32': + pred = self.model.encode_image(batch).cpu().numpy() + elif self.model_name == 'InceptionV3': + pred = self.model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. 
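As a quick sanity check for the Frechet distance implemented above (an editor's sketch, not part of this file): the distance between two identical Gaussians should be ~0, and with equal covariances it reduces to the squared L2 shift of the means.

import numpy as np
from scipy import linalg

def frechet(mu1, sigma1, mu2, sigma2):
    # Simplified form of calculate_frechet_distance above (no eps fallback).
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

mu, sigma = np.zeros(4), np.eye(4)
print(frechet(mu, sigma, mu, sigma))        # ~0.0 for identical Gaussians
print(frechet(mu, sigma, mu + 1.0, sigma))  # 4.0: squared shift of the means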
+ if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + start_idx = start_idx + pred.shape[0] + + return pred_arr + + def calculate_activation_statistics(self, images): + act = self.get_activations(images) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + def pil_images_to_tensor(self, images_list): + """Convert a list of PIL Images to a torch.Tensor.""" + tensors_list = [self.preprocess(img) for img in images_list] + return torch.stack(tensors_list).cuda() # BxCxHxW format + + def calculate_score(self, batch): + m1, s1 = self.calculate_activation_statistics(batch['gt_im']) + m2, s2 = self.calculate_activation_statistics(batch['gen_im']) + fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + def reset(self): + pass + +class ImageDataset(torch.utils.data.Dataset): + def __init__(self, images, processor=None): + self.images = images + self.processor = processor + + def __len__(self): + return len(self.images) + + def __getitem__(self, i): + img = self.images[i] + img = self.processor(img) + return img \ No newline at end of file diff --git a/starvector/metrics/compute_l2.py b/starvector/metrics/compute_l2.py new file mode 100644 index 0000000000000000000000000000000000000000..ecca16a88f6d39d4b1666da4762202ce1e69d0bb --- /dev/null +++ b/starvector/metrics/compute_l2.py @@ -0,0 +1,37 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch + +class L2DistanceCalculator(BaseMetric): + def __init__(self, config=None, masked_l2=False): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.l2_distance + self.masked_l2 = masked_l2 + + def l2_distance(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + image1_tensor = ToTensor()(image1) + image2_tensor = ToTensor()(image2) + + if self.masked_l2: + # Create binary masks: 0 for white pixels, 1 for non-white pixels + mask1 = (image1_tensor != 1).any(dim=0).float() + mask2 = (image2_tensor != 1).any(dim=0).float() + + # Create a combined mask for overlapping non-white pixels + combined_mask = mask1 * mask2 + + # Apply the combined mask to both images + image1_tensor = image1_tensor * combined_mask.unsqueeze(0) + image2_tensor = image2_tensor * combined_mask.unsqueeze(0) + + # Compute mean squared error + mse = F.mse_loss(image1_tensor, image2_tensor) + return mse.item() + + + diff --git a/starvector/metrics/count_token_length.py b/starvector/metrics/count_token_length.py new file mode 100644 index 0000000000000000000000000000000000000000..8210771ec6fc7a148c069770c126610756a466e8 --- /dev/null +++ b/starvector/metrics/count_token_length.py @@ -0,0 +1,54 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from starvector.metrics.util import AverageMeter + +from transformers import AutoTokenizer + +class CountTokenLength(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b") + self.metric = self.calculate_token_length + self.meter_gt_tokens = AverageMeter() + self.meter_gen_tokens = AverageMeter() + self.meter_diff = AverageMeter() + + def calculate_token_length(self, **kwargs): + 
svg = kwargs.get('gt_svg') + tokens = self.tokenizer.encode(svg) + gen_svg = kwargs.get('gen_svg') + gen_tokens = self.tokenizer.encode(gen_svg) + diff = len(gen_tokens) - len(tokens) + return len(tokens), len(gen_tokens), diff + + def calculate_score(self, batch, update=None): + gt_svgs = batch['gt_svg'] + gen_svgs = batch['gen_svg'] + values = [] + for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"): + gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg) + self.meter_gt_tokens.update(gt_tokens, 1) + self.meter_gen_tokens.update(gen_tokens, 1) + self.meter_diff.update(diff, 1) + values.append({ + 'gt_tokens': gt_tokens, + 'gen_tokens': gen_tokens, + 'diff': diff + }) + avg_score = { + 'gt_tokens': self.meter_gt_tokens.avg, + 'gen_tokens': self.meter_gen_tokens.avg, + 'diff': self.meter_diff.avg + } + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + return avg_score, values + + def reset(self): + self.meter_gt_tokens.reset() + self.meter_gen_tokens.reset() + self.meter_diff.reset() diff --git a/starvector/metrics/inception.py b/starvector/metrics/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..cc56870522a1ee298abbfff9edd7243a0bf8e7dd --- /dev/null +++ b/starvector/metrics/inception.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. 
Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = _inception_v3(weights='DEFAULT') + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def _inception_v3(*args, **kwargs): + """Wraps `torchvision.models.inception_v3`""" + try: + version = tuple(map(int, torchvision.__version__.split('.')[:2])) + except ValueError: + # Just a caution against weird version strings + version = (0,) + + # Skips default weight inititialization if supported by torchvision + # version. See https://github.com/mseitzer/pytorch-fid/issues/28. + if version >= (0, 6): + kwargs['init_weights'] = False + + # Backwards compatibility: `weights` argument was handled by `pretrained` + # argument prior to version 0.13. 
+ if version < (0, 13) and 'weights' in kwargs: + if kwargs['weights'] == 'DEFAULT': + kwargs['pretrained'] = True + elif kwargs['weights'] is None: + kwargs['pretrained'] = False + else: + raise ValueError( + 'weights=={} not supported in torchvision {}'.format( + kwargs['weights'], torchvision.__version__ + ) + ) + del kwargs['weights'] + + return torchvision.models.inception_v3(*args, **kwargs) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = _inception_v3(num_classes=1008, + aux_logits=False, + weights=None) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(torchvision.models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(torchvision.models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(torchvision.models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, 
self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(torchvision.models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) \ No newline at end of file diff --git a/starvector/metrics/metrics.py b/starvector/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..86d7cdcd6dd5798e8bb35b8fe472f687d96f3b4d --- /dev/null +++ b/starvector/metrics/metrics.py @@ -0,0 +1,127 @@ +from starvector.metrics.compute_l2 import L2DistanceCalculator +from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator +from starvector.metrics.compute_SSIM import SSIMDistanceCalculator +from starvector.metrics.compute_fid import FIDCalculator +from starvector.metrics.compute_clip_score import CLIPScoreCalculator +from starvector.data.util import rasterize_svg +from starvector.metrics.util import AverageMeter +from starvector.metrics.compute_dino_score import DINOScoreCalculator +from starvector.metrics.count_token_length import CountTokenLength +import os +from tqdm import tqdm + +class SVGMetrics: + def __init__(self, config=None): + self.class_name = self.__class__.__name__ + + default_config = { + 'L2': True, + 'Masked-L2': False, + 'LPIPS': False, + 'SSIM': False, + 'FID': False, + 'FID_clip': False, + 'CLIPScore': False, + 'CountTokenLength': False, + 'ratio_post_processed': True, + 'ratio_non_compiling': True, + 'DinoScore': True, + } + self.config = config or default_config + + self.metrics = { + 'L2': L2DistanceCalculator, + 'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True), + 'LPIPS': LPIPSDistanceCalculator, + 'SSIM': SSIMDistanceCalculator, + 'FID': lambda: FIDCalculator(model_name='InceptionV3'), + 'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'), + 
'CLIPScore': CLIPScoreCalculator, + 'CountTokenLength': CountTokenLength, + 'ratio_post_processed': AverageMeter, + 'ratio_non_compiling': AverageMeter, + 'DinoScore': DINOScoreCalculator, + } + + self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)} + + def reset(self): + for metric in self.active_metrics.values(): + metric.reset() + + def batch_contains_raster(self, batch): + return "gt_im" in batch and "gen_im" in batch + + def batch_contains_svg(self, batch): + return "gt_svg" in batch and "gen_svg" in batch + + def calculate_metrics(self, batch, update=True): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + avg_results_dict = {} + all_results_dict = {} + + def get_sample_id(json_item): + return json_item.get('outpath_filename') or json_item.get('sample_id') + + # initialize all_results_dict + for i, json_item in enumerate(batch['json']): + sample_id = get_sample_id(json_item) + if sample_id is None: + raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]") + all_results_dict[sample_id] = {} + + for metric_name, metric in self.active_metrics.items(): + print(f"Calculating {metric_name}...") + + # Handle metrics that return both average and per-sample results + if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']: + avg_result, list_result = metric.calculate_score(batch, update=update) + avg_results_dict[metric_name] = avg_result + + # Store individual results + for i, result in enumerate(list_result): + sample_id = get_sample_id(batch['json'][i]) + all_results_dict[sample_id][metric_name] = result + + # Handle FID metrics that only return average + elif metric_name in ['FID', 'FID_clip']: + avg_results_dict[metric_name] = metric.calculate_score(batch) + + # Handle other metrics (ratio metrics) + else: + self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict) + + metric.reset() + print("Average results: \n", avg_results_dict) + return avg_results_dict, all_results_dict + + def calculate_fid(self, batch): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + return self.active_metrics['FID'].calculate_score(batch).item() + + def get_average_metrics(self): + metrics = {} + for metric_name, metric in self.active_metrics.items(): + if hasattr(metric, 'avg'): + metrics[metric_name] = metric.avg + elif hasattr(metric, 'get_average_score'): + metrics[metric_name] = metric.get_average_score() + return metrics + + def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict): + """Helper method to handle ratio-based metrics.""" + metric_key = metric_name.replace('avg_', '').replace('ratio_', '') + + for item in batch['json']: + sample_id = get_sample_id(item) + value = item[metric_key] + all_results_dict[sample_id][metric_name] = value + metric.update(value, 1) + + avg_results_dict[metric_name] = metric.avg \ No newline at end of file diff --git a/starvector/metrics/util.py b/starvector/metrics/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1faac0ed299c21092234a19ef570584981573cc7 --- /dev/null +++ b/starvector/metrics/util.py @@ -0,0 +1,20 @@ + +# -------------- Metrics -------------- +class AverageMeter(object): + 
"""Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + \ No newline at end of file diff --git a/starvector/model/adapters/adapter.py b/starvector/model/adapters/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/model/adapters/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/model/builder.py b/starvector/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..785a2dfdbd131853760813d1ccf1a01da5d94fa7 --- /dev/null +++ b/starvector/model/builder.py @@ -0,0 +1,49 @@ + +from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig +from starvector.data.base import ImageTrainProcessor +from starvector.util import dtype_mapping +from transformers import AutoConfig + +def load_pretrained_model(model_path, device="cuda", **kwargs): + model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device) + tokenizer = model.model.svg_transformer.tokenizer + image_processor = ImageTrainProcessor() + context_len = model.model.query_length + model.model.max_length + return tokenizer, model, image_processor, context_len + +def model_builder(config): + model_name = config.model.get("model_name", False) + + args = { + "task": config.model.task, + "train_image_encoder": config.training.train_image_encoder, + "ignore_mismatched_sizes": True, + "starcoder_model_name": config.model.starcoder_model_name, + "train_LLM": config.training.train_LLM, + "torch_dtype": dtype_mapping[config.training.model_precision], + "transformer_layer_cls": config.model.get("transformer_layer_cls", False), + "use_cache": config.model.use_cache, + } 
+ if model_name: + model = StarVectorForCausalLM.from_pretrained(model_name, **args) + else: + starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name) + + starvector_config = StarVectorConfig( + max_length_train=config.model.max_length, + image_encoder_type=config.model.image_encoder_type, + use_flash_attn=config.model.use_flash_attn, + adapter_norm=config.model.adapter_norm, + starcoder_model_name=config.model.starcoder_model_name, + torch_dtype=dtype_mapping[config.training.model_precision], + num_attention_heads=starcoder_model_config.num_attention_heads, + num_hidden_layers=starcoder_model_config.num_hidden_layers, + vocab_size=starcoder_model_config.vocab_size, + hidden_size=starcoder_model_config.hidden_size, + num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None), + ) + model = StarVectorForCausalLM(starvector_config, **args) + + return model + + \ No newline at end of file diff --git a/starvector/model/gpt_bigcode/__init__.py b/starvector/model/gpt_bigcode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1678bafc28fa8227101eb712cd41968269493c2b --- /dev/null +++ b/starvector/model/gpt_bigcode/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_bigcode"] = [ + "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeForCausalLM", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeForCausalLM, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeModel, + GPTBigCodePreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..ececb6332f9bad2b9af559f1a586f688c2996c78 --- /dev/null +++ b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2023 The BigCode team and HuggingFace Inc. 
team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPTBigCode configuration""" +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + + +logger = logging.get_logger(__name__) + + + + +class GPTBigCodeConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a + GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTBigCode + [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTBigCodeModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", + "gelu_pytorch_tanh"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to call the fused softmax in float32. 
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to scale the attention softmax in float32. + attention_type (`bool`, *optional*, defaults to `True`): + Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`). + Example: + + ```python + >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel + + >>> # Initializing a GPTBigCode configuration + >>> configuration = GPTBigCodeConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPTBigCodeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt_bigcode" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_pytorch_tanh", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + multi_query=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.multi_query = multi_query + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..b8334b2cbe65bb20288849eb0cf9747f48a48f0d --- /dev/null +++ b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py @@ -0,0 +1,1502 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
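A small illustrative sketch (editor's addition, not part of the original file) of how the attribute_map declared in GPTBigCodeConfig above exposes the GPT-2-style fields under the standard Hugging Face names:

from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig

cfg = GPTBigCodeConfig(n_embd=2048, n_layer=24, n_head=16)
assert cfg.hidden_size == 2048          # mapped from n_embd
assert cfg.num_hidden_layers == 24      # mapped from n_layer
assert cfg.num_attention_heads == 16    # mapped from n_head
assert cfg.multi_query is True          # Multi-Query Attention is the default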
+"""PyTorch GPTBigCode model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" +_CONFIG_FOR_DOC = "GPTBigCodeConfig" + + + +# Fused kernels +# Use separate functions for each case because conditionals prevent kernel fusion. +# TODO: Could have better fused kernels depending on scaling, dropout and head mask. +# Is it doable without writing 32 functions? +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class GPTBigCodeAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + + self.mask_value = None + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.kv_heads = 1 if self.multi_query else self.num_heads + self.kv_dim = self.kv_heads * self.head_dim + self.split_size = self.embed_dim + self.is_causal = True + + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + self.layer_idx = layer_idx + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = ( + config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 + ) + self.attn_pdrop = config.attn_pdrop + + if self.is_cross_attention: + if self.multi_query: + raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + + self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) + self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) + else: + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def _get_mask_value(self, device, dtype): + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype + + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.multi_query: + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. + attn_weights = torch.zeros_like(attn_weights) + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. 
+ # Sub-optimal when the key length is not a multiple of 8. + if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. + attn_weights = torch.where(attention_mask, attn_weights, mask_value) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + if self.multi_query: + head_mask = head_mask.transpose(1, 2) + attn_weights = attn_weights * head_mask + + if self.multi_query: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPTBigCodeFlashAttention2(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. 
This module inherits from `GPTBigCodeAttention`, as the weights of the module
+    stay untouched. The only required change is in the forward pass, where it needs to correctly call the public
+    API of flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[
+        Tuple[torch.Tensor, Optional[torch.Tensor]],
+        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+    ]:
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn") or not self.is_cross_attention:
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+                )
+
+            query = self.q_attn(hidden_states)
+            key_value = self.c_attn(encoder_hidden_states)
+            attention_mask = encoder_attention_mask
+        elif self.multi_query:
+            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
+        else:
+            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
+            # i.e., the memory layout is not the same as GPT2.
+            # This makes the concatenation with past_key_value more efficient.
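+            # Per attention head, the c_attn projection is laid out as [query | key | value]
+            # (width 3 * head_dim), so one split yields `query` (head_dim) and the packed
+            # `key_value` tensor (2 * head_dim); the cached `layer_past` below is then appended
+            # with a single torch.cat along the sequence dimension.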
+            query, key_value = (
+                self.c_attn(hidden_states)
+                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+                .transpose(1, 2)
+                .split((self.head_dim, 2 * self.head_dim), dim=3)
+            )
+
+        if layer_past is not None:
+            key_value = torch.cat((layer_past, key_value), dim=-2)
+        present = key_value if use_cache else None
+
+        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        if self.multi_query:
+            batch_size, query_length, _ = query.shape
+            query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
+            key = key.unsqueeze(2)
+            value = value.unsqueeze(2)
+        else:
+            query_length = query.shape[2]
+            batch_size, _, tgt, _ = key.shape
+            query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
+            key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+            value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+
+        attn_dropout = self.attn_pdrop if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
+        # hidden states may get silently cast to float32. We therefore cast them back to the expected
+        # dtype to make sure everything works as expected.
+        input_dtype = query.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.c_attn.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be because"
+                f" you have upcasted embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+            query = query.to(target_dtype)
+            key = key.to(target_dtype)
+            value = value.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query, key, value, attention_mask, query_length, dropout=attn_dropout
+        )
+
+        attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+        attn_output = self.c_proj(attn_weights_reshaped)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+
+        if output_attentions:
+            if self.multi_query:
+                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+                attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
+            else:
+                attn_weights_reshaped = None
+
+            outputs += (attn_weights_reshaped,)
+
+        return outputs  # a, present, (attentions)
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores, and finally pad the attention output back.
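+        Padding is handled with `unpad_input`/`pad_input` from `flash_attn.bert_padding`: padded positions are
+        removed before calling `flash_attn_varlen_func` and re-inserted afterwards, so attention is only computed
+        over real tokens.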
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
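+            # With left padding, the tokens of the current step are the last `query_length` positions
+            # of the sequence, so the last `query_length` columns of the 2D mask are exactly the
+            # entries that correspond to the query tokens.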
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPTBigCodeSdpaAttention(GPTBigCodeAttention): + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + if head_mask is not None: + # The super dispatch is done in the forward. + raise ValueError( + "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository." + ) + + scale = None + if not self.scale_attn_weights: + scale = 1 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key.shape[-2] + + if self.multi_query: + query_length = query_shape[1] + + # SDPA requires the dimension [..., sequence_length, head_dim]. + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) + + # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend + # and flash attention backend (No available kernel. Aborting execution.) from the shapes + # query = [batch_size, num_heads, query_length, head_dim] + # key = [batch_size, 1, past_length, head_dim] + # value = [batch_size, 1, past_length, head_dim] + # + # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. + if is_torch_greater_or_equal_than_2_2: + key = key.expand(-1, self.num_heads, -1, -1) + value = value.expand(-1, self.num_heads, -1, -1) + else: + query_length = query_shape[-1] + + # See the comment above. + if query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attn_pdrop if self.training else 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
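+            # is_causal is only set when no explicit mask is given: SDPA rejects passing both an
+            # attn_mask and is_causal=True, and when a 4D mask is provided it already encodes the
+            # causal structure built in GPTBigCodeModel.forward.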
+ is_causal=self.is_causal and attention_mask is None and query_length > 1, + scale=scale, + ) + + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) + + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) + + return sdpa_result, None + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + if not output_attentions and head_mask is None: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + else: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None." + ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPTBIGCODE_ATTENTION_CLASSES = { + "eager": GPTBigCodeAttention, + "flash_attention_2": GPTBigCodeFlashAttention2, + "sdpa": GPTBigCodeSdpaAttention, +} + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + if config.multi_query: + raise NotImplementedError("Cross-attention not implemented for MQA") + + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( + config, is_cross_attention=True, layer_idx=layer_idx + ) + + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPTBigCodeMLP(self.inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.Tensor]], + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention 
layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["GPTBigCodeBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +GPT_BIGCODE_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT_BIGCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.", + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False + ) + + self.gradient_checkpointing = False + + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") 
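+
+        # Note: the per-layer cache of this model is a single packed tensor rather than a (key, value)
+        # tuple: (batch_size, key_length, 2 * head_dim) with multi-query attention, or
+        # (batch_size, num_heads, key_length, 2 * head_dim) otherwise, which is why `past_length`
+        # is read from `size(-2)` below.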
+ + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0].size(-2) + + if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_length > 0: + position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] + elif position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Self-attention mask. + query_length = input_shape[-1] + key_length = past_length + query_length + self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None + encoder_attention_mask = ( + encoder_attention_mask.bool() + if (encoder_attention_mask is not None and 0 in encoder_attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, device=self_attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + if self._use_sdpa and head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + min_dtype = torch.finfo(dtype).min + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + ) + + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if self.multi_query: + # gpt_bigcode using MQA has the bad taste to use a causal mask with shape + # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. + self_attention_mask = self_attention_mask.transpose(1, 2) + + if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda": + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. 
Details: https://github.com/pytorch/pytorch/issues/110213
+                    self_attention_mask = AttentionMaskConverter._unmask_unattended(
+                        self_attention_mask, min_dtype=min_dtype
+                    )
+
+            attention_mask = self_attention_mask
+
+            # If a 2D or 3D attention mask is provided for the cross-attention,
+            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if (
+                self.config.add_cross_attention
+                and encoder_hidden_states is not None
+                and encoder_attention_mask is not None
+            ):
+                if encoder_attention_mask.dim() == 2:
+                    encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+                assert encoder_attention_mask.dim() == 3
+                encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
+            else:
+                encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = [] if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    use_cache,
+                    output_attentions,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache:
+                presents.append(outputs[1])
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTBigCodeModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) + + +@add_start_docstrings( + """ + The GPTBigCode Model transformer with a sequence classification head on top (linear layer). + + [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTBigCodeModel(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPTBigCodeModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/starvector/model/image_encoder/clip_model.py b/starvector/model/image_encoder/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/model/image_encoder/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if 
+            self.attn = checkpoint_wrapper(self.attn)
+            self.mlp = checkpoint_wrapper(self.mlp)
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.num_features = width
+        self.num_heads = heads
+        self.num_patches = (input_resolution // patch_size) ** 2
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
+        self.ln_pre = LayerNorm(width)
+        self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        return x
diff --git a/starvector/model/image_encoder/image_encoder.py b/starvector/model/image_encoder/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a96dfc001fcf194611906437174f9469b440b1
--- /dev/null
+++ b/starvector/model/image_encoder/image_encoder.py
@@ -0,0 +1,120 @@
+import os
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+from starvector.model.image_encoder.clip_model import convert_weights_to_precision
+from starvector.data.util import ImageTrainProcessor
+
+class ImageEncoder(nn.Module):
+    def __init__(self, config, **kwargs):
+        super(ImageEncoder, self).__init__()
+
+        image_size = config.image_size
+        torch_dtype = kwargs.get('model_precision', config.torch_dtype)
+        self.image_encoder_type = config.image_encoder_type
+        if self.image_encoder_type == 'clip':
+            self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size)
+            convert_weights_to_precision(self, torch_dtype)
+            self.processor = ImageTrainProcessor(size=config.image_size)
+
+        elif self.image_encoder_type == 'vqgan':
+            self.visual_encoder = self.build_vqgan_encoder()
+            self.ln_vision = None
+            self.processor = ImageTrainProcessor(size=config.image_size)
+
+        elif self.image_encoder_type == 'convnext':
+            self.visual_encoder = self.build_convnext_encoder()
+            self.ln_vision = None
+            self.processor = ImageTrainProcessor(size=config.image_size)
+
+        elif 'siglip' in self.image_encoder_type:
+            if self.image_encoder_type == 'siglip_512':
+                model_name = "google/siglip-base-patch16-512"
+            elif self.image_encoder_type == 'siglip_384':
+                model_name = "google/siglip-large-patch16-384"
+            elif self.image_encoder_type == 'siglip_256':
+                model_name = "google/siglip-base-patch16-256"
+
+            from transformers import AutoProcessor, AutoModel
+
+            self.visual_encoder = AutoModel.from_pretrained(
+                model_name, torch_dtype=torch_dtype
+            ).vision_model
+
+            self.processor = AutoProcessor.from_pretrained(
+                model_name, torch_dtype=torch_dtype
+            )
+
+    def build_clip_encoder(self, image_size):
+        from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm
+        visual_encoder = VisionTransformer(
+            input_resolution=image_size,
+            patch_size=14,
+            width=1024,
+            layers=23,
+            heads=16,
+            use_grad_checkpointing=False)
+
+        ln_vision = LayerNorm(visual_encoder.num_features)
+        return visual_encoder, ln_vision
+
+    def build_vqgan_encoder(self):
+        from taming.modules.diffusionmodules.model import Encoder
+        VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint"  # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md
+        vqgan_chkp_path = VQGAN_CHECKPOINT
+        files_in_directory = os.listdir(vqgan_chkp_path + '/configs')
+        vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0]
+        vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file))
+        visual_encoder = Encoder(**vqgan_config.model.params.ddconfig)
+
+        # Load checkpoint weights
+        checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict']
+
+        # Create a new state_dict with modified keys
+        new_state_dict = {}
+        for key, value in checkpoint.items():
+            if key.startswith('encoder.'):
+                new_key = key[len('encoder.'):]
+                new_state_dict[new_key] = value
+
+        # Load weights
+        visual_encoder.load_state_dict(new_state_dict)
+        return visual_encoder
+
+    def build_convnext_encoder(self):
+        import open_clip
+        model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
+        return model.visual
+
+    def forward(self, image):
+        if self.image_encoder_type == 'clip':
+            embeds = self.visual_encoder(image)
+            out = self.ln_vision(embeds)
+        elif self.image_encoder_type == 'open-clip':
+            out = self.visual_encoder(image)[1]
+            out = self.ln_vision(out)
+        elif self.image_encoder_type == 'vqgan':
+            out = self.visual_encoder(image)
+            size = out.size()
+            out = out.view(size[0], size[1], -1)
+            out = out.permute(0, 2, 1)
+        elif self.image_encoder_type == 'convnext':
+            out = self.visual_encoder.trunk.forward_features(image)
+            size = out.size()
+            out = out.view(size[0], size[1], -1)
+            out = out.permute(0, 2, 1)
+        elif 'siglip' in self.image_encoder_type:
+            out = self.visual_encoder(image)["last_hidden_state"]
+        return out
+
+    def process_images(self, images):
+        if self.image_encoder_type == 'clip':
+            res = []
+            for image in images:
+                res.append(self.processor(image).unsqueeze(0))  # B, 3, H, W
+            return res
+        else:
+            return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0)
+
\ No newline at end of file
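For orientation, a minimal sketch of driving the ImageEncoder above end to end. The SimpleNamespace config is a stand-in for whatever configuration object the rest of the codebase supplies; only the fields read by __init__ (image_encoder_type, image_size, torch_dtype) are assumed, and example.png is a placeholder path.

    import torch
    from types import SimpleNamespace
    from PIL import Image
    from starvector.model.image_encoder.image_encoder import ImageEncoder

    # Hypothetical config carrying only the fields ImageEncoder.__init__ reads.
    config = SimpleNamespace(image_encoder_type="clip", image_size=224, torch_dtype=torch.float32)
    encoder = ImageEncoder(config).eval()  # the CLIP ViT weights are randomly initialized unless loaded elsewhere

    image = Image.open("example.png").convert("RGB")   # placeholder image path
    pixel_batches = encoder.process_images([image])    # 'clip' branch returns a list of (1, 3, H, W) tensors
    pixels = torch.cat(pixel_batches, dim=0)           # (B, 3, H, W)
    with torch.no_grad():
        features = encoder(pixels)                     # (B, num_patches + 1, width) after ln_vision
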
diff --git a/starvector/model/llm/starcoder.py b/starvector/model/llm/starcoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09b3c6d5c1bfd04dc1ea02c6fd7f25602d16e03
--- /dev/null
+++ b/starvector/model/llm/starcoder.py
@@ -0,0 +1,51 @@
+import torch.nn as nn
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    )
+
+class StarCoderModel(nn.Module):
+    def __init__(self, config, **kwargs):
+        super(StarCoderModel, self).__init__()
+
+        self.init_tokenizer(config.starcoder_model_name)
+
+        self.max_length = config.max_length
+        model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True)
+        # Keyword arguments forwarded to from_pretrained below
+        kwargs = {}
+        kwargs['trust_remote_code'] = True
+        kwargs['torch_dtype'] = config.torch_dtype
+
+        # Configure special tokens for generation
+        model_config.eos_token_id = self.tokenizer.eos_token_id
+        model_config.pad_token_id = self.tokenizer.pad_token_id
+        model_config.bos_token_id = self.tokenizer.bos_token_id
+        try:
+            model_config.flash_attention = config.use_flash_attn
+            model_config._attn_implementation = "flash_attention_2"
+        except ImportError:
+            config.use_flash_attn = False
+
+        # model = GPTBigCodeForCausalLM(config=model_config)
+        model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs)
+        model.resize_token_embeddings(len(self.tokenizer))
+        self.transformer = model
+
+        # Prompt the model after the image
+        self.prompt = '