Prince53 committed on
Commit 77537a2 · verified · 1 Parent(s): 91bbb49

Upload 144 files

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +21 -0
  2. Geo/GeochatP-main/.github/workflows/deploy_to_huggingface.yml +20 -0
  3. Geo/GeochatP-main/README.md +227 -0
  4. Geo/GeochatP-main/app.py +35 -0
  5. Geo/GeochatP-main/demo_images/04133.png +3 -0
  6. Geo/GeochatP-main/demo_images/04444.png +3 -0
  7. Geo/GeochatP-main/demo_images/7292.JPG +3 -0
  8. Geo/GeochatP-main/demo_images/MicrosoftTeams-image.png +3 -0
  9. Geo/GeochatP-main/demo_images/church_183.png +3 -0
  10. Geo/GeochatP-main/demo_images/train_2956_0001.png +3 -0
  11. Geo/GeochatP-main/docs/Customize_Component.md +20 -0
  12. Geo/GeochatP-main/docs/Data.md +24 -0
  13. Geo/GeochatP-main/docs/Evaluation.md +54 -0
  14. Geo/GeochatP-main/docs/LoRA.md +24 -0
  15. Geo/GeochatP-main/docs/MODEL_ZOO.md +18 -0
  16. Geo/GeochatP-main/docs/geochat_supp.pdf +3 -0
  17. Geo/GeochatP-main/geochat/__init__.py +1 -0
  18. Geo/GeochatP-main/geochat/constants.py +12 -0
  19. Geo/GeochatP-main/geochat/conversation.py +520 -0
  20. Geo/GeochatP-main/geochat/eval/batch_geochat_grounding.py +138 -0
  21. Geo/GeochatP-main/geochat/eval/batch_geochat_referring.py +132 -0
  22. Geo/GeochatP-main/geochat/eval/batch_geochat_scene.py +139 -0
  23. Geo/GeochatP-main/geochat/eval/batch_geochat_vqa.py +125 -0
  24. Geo/GeochatP-main/geochat/mm_utils.py +121 -0
  25. Geo/GeochatP-main/geochat/model/__init__.py +2 -0
  26. Geo/GeochatP-main/geochat/model/apply_delta.py +48 -0
  27. Geo/GeochatP-main/geochat/model/builder.py +149 -0
  28. Geo/GeochatP-main/geochat/model/consolidate.py +29 -0
  29. Geo/GeochatP-main/geochat/model/geochat_arch.py +262 -0
  30. Geo/GeochatP-main/geochat/model/language_model/geochat_llama.py +140 -0
  31. Geo/GeochatP-main/geochat/model/language_model/geochat_mpt.py +113 -0
  32. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc +0 -0
  33. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-311.pyc +0 -0
  34. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-38.pyc +0 -0
  35. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-39.pyc +0 -0
  36. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-310.pyc +0 -0
  37. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-311.pyc +0 -0
  38. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-38.pyc +0 -0
  39. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-39.pyc +0 -0
  40. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-310.pyc +0 -0
  41. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-311.pyc +0 -0
  42. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-38.pyc +0 -0
  43. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-39.pyc +0 -0
  44. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-310.pyc +0 -0
  45. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-311.pyc +0 -0
  46. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-38.pyc +0 -0
  47. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-39.pyc +0 -0
  48. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-310.pyc +0 -0
  49. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-311.pyc +0 -0
  50. Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/04133.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/04444.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/7292.JPG filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/church_183.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/MicrosoftTeams-image.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/demo_images/train_2956_0001.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/docs/geochat_supp.pdf filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/geochat/serve/examples/11760.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/geochat/serve/examples/11765.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/architecture.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/dataset.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/examples.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/grounded.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/iden.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/logo_geochat.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/overview2.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/ref_2.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/ref1.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/scene.jpg filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/teaser.png filter=lfs diff=lfs merge=lfs -text
+ Geo/GeochatP-main/images/vqa.jpg filter=lfs diff=lfs merge=lfs -text
Geo/GeochatP-main/.github/workflows/deploy_to_huggingface.yml ADDED
@@ -0,0 +1,20 @@
+ name: Deploy to Hugging Face
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   deploy:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v3
+
+       - name: Upload to Hugging Face Space
+         uses: huggingface/hub-action@v1 # Use v1 instead of v1.4.4
+         with:
+           repo-id: Prince53/GeochatP # Make sure this matches your Hugging Face username and space name
+           repo-type: space
+           hf-token: ${{ secrets.HF_TOKEN }}
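
If the marketplace action above ever proves unreliable, the same upload can be scripted directly with the `huggingface_hub` client; a minimal sketch, assuming `HF_TOKEN` is exported in the job's environment:

```python
# Push the checked-out repository to the Space, mirroring the workflow step above.
import os
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_folder(
    folder_path=".",                # repository root checked out by actions/checkout
    repo_id="Prince53/GeochatP",
    repo_type="space",
)
```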
Geo/GeochatP-main/README.md ADDED
@@ -0,0 +1,227 @@
1
+ # GeoChat <img src="images/logo_geochat.png" height="40">: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
2
+ <p align="center">
3
+ <img src="https://i.imgur.com/waxVImv.png" alt="Oryx Video-ChatGPT">
4
+ </p>
5
+
6
+ #### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
7
+ \* Equally contributing first authors
8
+
9
+ #### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
10
+
11
+ [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://mbzuai-oryx.github.io/GeoChat)
12
+ [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2311.15826)
13
+ [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](https://youtu.be/KOKtkkKpNDk)
14
+
15
+ ---
16
+
17
+ ## 📢 Latest Updates
18
+ - Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
19
+ - **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
20
+ - **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
21
+ - **Nov-28-23**: GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
22
+ ---
23
+
24
+
25
+
26
+ ## <img src="images/logo_geochat.png" height="40">Overview
27
+
28
+ GeoChat is the first grounded Large Vision-Language Model specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels at handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned using the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
29
+
30
+ ---
31
+ ## Contents
32
+ - [Install](#install)
33
+ - [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
34
+ - [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
35
+ - [Train](#train)
36
+ - [Evaluation](#evaluation)
37
+
38
+ ## Install
39
+
40
+ 1. Clone this repository and navigate to GeoChat folder
41
+ ```bash
42
+ git clone https://github.com/mbzuai-oryx/GeoChat.git
43
+ cd GeoChat
44
+ ```
45
+
46
+ 2. Install Package
47
+ ```Shell
48
+ conda create -n geochat python=3.10 -y
49
+ conda activate geochat
50
+ pip install --upgrade pip # enable PEP 660 support
51
+ pip install -e .
52
+ ```
53
+
54
+ 3. Install additional packages for training cases
55
+ ```
56
+ pip install ninja
57
+ pip install flash-attn --no-build-isolation
58
+ ```
59
+
60
+ ### Upgrade to latest code base
61
+
62
+ ```Shell
63
+ git pull
64
+ pip uninstall transformers
65
+ pip install -e .
66
+ ```
67
+
68
+ ## GeoChat Weights and Demo
69
+ Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and check [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
70
+
71
+ ## Train
72
+
73
+ GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset: 318k Vicuna-generated multimodal instruction-following samples, finetuned over the pretrained weights of LLaVA-v1.5.
74
+
75
+ We train GeoChat on 3 A100 GPUs with 40GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
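
As a quick sanity check on that rule, here is a small illustrative calculation (the per-device and accumulation values are example numbers, not the ones in `finetune_lora.sh`):

```python
# Keep the global batch size fixed at 144 (see the hyperparameter table below).
num_gpus = 3
per_device_train_batch_size = 16    # example value
gradient_accumulation_steps = 3     # example value
global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
assert global_batch_size == 144
```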
76
+
77
+ ### Hyperparameters
78
+ We use a similar set of hyperparameters as Vicuna in finetuning. Both hyperparameters used in pretraining and finetuning are provided below.
79
+
80
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
81
+ | --- | ---: | ---: | ---: | ---: | ---: |
82
+ | GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
83
+
84
+ ### Pretrain (feature alignment)
85
+
86
+ We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. Pretraining takes around 3.5 hours for LLaVA-v1.5-7B.
87
+
88
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
89
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
90
+
91
+ ### Visual Instruction Tuning
92
+
93
+ 1. Prepare data
94
+
95
+ Please download the annotation of the final mixture of our instruction tuning data [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from the [hugging face](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save the multiple image zips in a single folder and run the following command to merge them:
96
+ ```Shell
97
+ cat images_parta* > images.zip
98
+ ```
99
+ Unzip the images.zip file to a folder and give the folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
100
+
101
+ 2. Start training!
102
+
103
+ Visual instruction tuning takes longer due to the increased CLIP resolution of 504×504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40G).
104
+
105
+ Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
106
+
107
+ Options to note:
108
+
109
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
110
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
111
+ - `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
112
+ - `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct).
113
+ -
114
+ ## Evaluation
115
+
116
+ We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding rather than beam search, keeping the inference process consistent with the real-time chat demo.
117
+ See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
118
+
119
+ ## 🏆 Contributions
120
+
121
+ - **RS multimodal instruction-following dataset.** We present a novel data-generation pipeline that leverages existing object detection datasets to create short descriptions of the images, and then uses Vicuna-v1.5 to create conversations from the generated text alone. Further, we add visual question-answering and scene-classification abilities using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
123
+ - **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote-sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully-tuned LLaVA model, whose MLP projection is trained to align images into the word-embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain the conversation and instruction-following abilities of LLaVA and extend its domain knowledge to remote sensing tasks.
124
+
125
+ - **Evaluation Benchmark.** We also address the lack of evaluation benchmarks to assess the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparisons with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote sensing tasks, including image captioning, visual question answering and scene classification, to demonstrate the generalisability of the GeoChat conversational VLM.
126
+
127
+ ---
128
+ ## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
129
+
130
+ GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations - shown on top), visual question answering on images and regions (top left and bottom right, respectively) as well as scene classification (top right) and normal natural language conversations (bottom). This makes it the first RS VLM with grounding capability.
131
+
132
+ <p align="center">
133
+ <img src="images/overview2.png" alt="GeoChat Overview">
134
+ </p>
135
+
136
+ ---
137
+
138
+ ## 🛰️ GeoChat : Architecture
139
+
140
+ An overview of GeoChat - the first grounded large vision-language model for remote sensing. Given an image input together with a user query, a visual backbone is first used to encode patch-level tokens at a higher resolution via interpolating positional encodings. A multi-layer perceptron (MLP) is used to adapt vision-tokens to language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be input to the model together with task-specific prompts that specify the desired task required by the user. Given this context, the LLM can generate natural language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks as shown on top e.g., scene classification, image/region captioning, VQA and grounded conversations.
141
+
142
+ <p align="center">
143
+ <img src="images/architecture.png" alt="GeoChat Architectural">
144
+ </p>
145
+
146
+ ---
147
+
148
+ ## 🔍 RS Multimodal Instruction Dataset
149
+
150
+ Types of annotations available in the GeoChat instruction-set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction-set with a total of 318k image-instruction pairs.
151
+
152
+ <p align="center">
153
+ <img src="images/dataset.png" alt="Dataset Annotation Pipeline">
154
+ </p>
155
+
156
+
157
+
158
+ ## 🤖 Qualitative results of GeoChat
159
+
160
+ Qualitative results of GeoChat. (<em>left-right</em>) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., <strong>[grounding]</strong>) to shape model responses according to the desired behavior. The model can generate textual responses (<em>right</em>), only visual grounding (<em>center</em>) and both text and object groundings interleaved together (<em>left</em>). The model can also specify object types, object counts, object attributes and object relationships.
161
+ <p align="center">
162
+ <img src="images/examples.png" alt="Results_GCG">
163
+ </p>
164
+
165
+ ---
166
+
167
+ ## 🤖 Visual Question Answering
168
+ Qualitative examples for Visual Question Answering tasks. GeoChat is able to hold multi-turn conversations based on various types of questions, including presence, count, complex comparisons and so on. It is also able to detect objects and hold conversations over low-resolution images.
169
+ <p align="center">
170
+ <img src="images/vqa.jpg" alt="Visual Question Answering">
171
+ </p>
172
+
173
+ ---
174
+
175
+ ## 🤖 Scene Classification
176
+ Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
177
+ <p align="center">
178
+ <img src="images/scene.jpg" alt="Visual Question Answering">
179
+ </p>
180
+
181
+ ---
182
+
183
+ ## 🤖 Grounded Description
184
+ When asked to describe the image with the special token '[grounding]', GeoChat outputs both the description of the image as well as the bounding boxes for all the objects detected.
185
+ <p align="center">
186
+ <img src="images/grounded.jpg" alt="Grounded Description">
187
+ </p>
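
For instance, a grounded-description request is just the task token prepended to an ordinary question (the wording after the token is an arbitrary example):

```python
question = "[grounding]" + "Describe the image in detail."   # '[grounding]' asks for boxes alongside the text
```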
188
+
189
+ ---
190
+
191
+ ## 🤖 Referring Expression
192
+ When asked about an object via a referring expression, GeoChat is able to locate it and draw a rotated bounding box around it.
193
+ <p align="center">
194
+ <img src="images/ref1.jpg" alt="Referring Expression">
195
+ </p>
196
+ <p align="center">
197
+ <img src="images/ref_2.jpg" alt="Referring Expression">
198
+ </p>
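
The referring prompts follow the pattern used by the evaluation scripts later in this diff, with the object phrase wrapped in `<p>` tags (the phrase below is an arbitrary example):

```python
question = "[refer] Give me the location of <p> " + "the large white storage tank" + " </p>"
```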
199
+
200
+ ---
201
+
202
+ ## 🤖 Region Caption
203
+ Qualitative examples for region-based captioning. Given a bounding box, GeoChat is able to provide brief descriptions about the area or the object covered by the bounding box.
204
+ <p align="center">
205
+ <img src="images/iden.jpg" alt="Region Caption">
206
+ </p>
207
+
208
+ ---
209
+
210
+ ## 📜 Citation
211
+ ```bibtex
212
+ @article{kuckreja2023geochat,
213
+ title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
214
+ author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
215
+ journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
216
+ year={2024}
217
+ }
218
+ ```
219
+ ## 🙏 Acknowledgement
220
+ We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
221
+
222
+ ---
223
+ [<img src="images/IVAL_logo.png" width="200" height="100">](https://www.ival-mbzuai.com)
224
+ [<img src="images/Oryx_logo.png" width="100" height="100">](https://github.com/mbzuai-oryx)
225
+ [<img src="images/MBZUAI_logo.png" width="360" height="85">](https://mbzuai.ac.ae)
226
+ Fixing Hugging Face deployment
227
+
Geo/GeochatP-main/app.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+
6
+ # Load model
7
+ class MyModel(torch.nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ # Define layers here
11
+
12
+ def forward(self, x):
13
+ # Forward pass
14
+ return x
15
+
16
+ model = MyModel()
17
+ model.load_state_dict(torch.load("model.pth"))
18
+ model.eval()
19
+
20
+ # Define image preprocessing
21
+ transform = transforms.Compose([
22
+ transforms.Resize((224, 224)),
23
+ transforms.ToTensor(),
24
+ ])
25
+
26
+ # Define prediction function
27
+ def predict(image):
28
+ image = transform(image).unsqueeze(0) # Add batch dimension
29
+ with torch.no_grad():
30
+ output = model(image)
31
+ return output.numpy().tolist()
32
+
33
+ # Create Gradio interface
34
+ iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")  # type="pil" so predict() receives a PIL image for the torchvision transforms
35
+ iface.launch()
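
For a quick local check of `predict` without the Gradio UI, something along these lines should work (a sketch that assumes `model.pth` exists and uses one of the repo's demo images):

```python
# Run from Geo/GeochatP-main after the definitions above.
from PIL import Image

img = Image.open("demo_images/church_183.png").convert("RGB")
print(predict(img))   # raw model output as a nested list
```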
Geo/GeochatP-main/demo_images/04133.png ADDED

Git LFS Details

  • SHA256: d554202729a40d67eb39fc38759e196ca628cf6f7f3c2679b075fcb9a9f52e80
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
Geo/GeochatP-main/demo_images/04444.png ADDED

Git LFS Details

  • SHA256: 9c6ec5f948638f44dd80814bfe20205a5c57ec9d01ebf3fcaf50e5a37f2067f5
  • Pointer size: 131 Bytes
  • Size of remote file: 969 kB
Geo/GeochatP-main/demo_images/7292.JPG ADDED

Git LFS Details

  • SHA256: 5a16bbdb6f4743afac0dc3ea914003c5609ff735966d88a3f6cfccddf837baaf
  • Pointer size: 132 Bytes
  • Size of remote file: 3.41 MB
Geo/GeochatP-main/demo_images/MicrosoftTeams-image.png ADDED

Git LFS Details

  • SHA256: 1b20fb8c3e814b8bb1895079ff1c02d0234dc21607ddada6c7cc0ba89f45e479
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
Geo/GeochatP-main/demo_images/church_183.png ADDED

Git LFS Details

  • SHA256: 225af61e9e76edbdfe995b6f8d9d1a07255e05d3b558653db687c800d293bdde
  • Pointer size: 131 Bytes
  • Size of remote file: 686 kB
Geo/GeochatP-main/demo_images/train_2956_0001.png ADDED

Git LFS Details

  • SHA256: 2bcd2e7cd60fb52bd786f7cc7705ea6ddb68238a17fc2da6bd8495d307f11ae9
  • Pointer size: 131 Bytes
  • Size of remote file: 680 kB
Geo/GeochatP-main/docs/Customize_Component.md ADDED
@@ -0,0 +1,20 @@
1
+ # Customize Components in GeoChat
2
+
3
+ This is an initial guide on how to replace the LLMs, visual encoders, etc. with your choice of components.
4
+
5
+ ## LLM
6
+
7
+ It is quite simple to swap out LLaMA for any other LLM. You can refer to our implementation of [`geochat_llama.py`](https://github.com/mbzuai-oryx/GeoChat/blob/main/geochat/model/language_model/geochat_llama.py) for an example of how to replace the LLM.
8
+
9
+ Although it may seem that it still needs ~100 lines of code, most of them are copied from the original `llama.py` from HF. The only part that is different is to insert some lines for processing the multimodal inputs.
10
+
11
+ In the `forward` function, you can see that we call `self.prepare_inputs_labels_for_multimodal` to process the multimodal inputs. This function is defined in `GeoChatMetaForCausalLM`, and you just need to insert it into the `forward` function of your LLM.
+
+ In the `prepare_inputs_for_generation` function, you can see that we add `images` to the `model_inputs`. This is because we need to pass the images to the LLM during generation.
14
+
15
+ These are basically all the changes you need to make to replace the LLM.
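
A rough sketch of those two touch points, assuming a hypothetical `MyLLMForCausalLM` base class; the exact argument order of `prepare_inputs_labels_for_multimodal` should be taken from `geochat_llama.py`/`geochat_arch.py`:

```python
from geochat.model.geochat_arch import GeoChatMetaForCausalLM   # import path assumed from this repo's layout


class MyGeoChatLLM(MyLLMForCausalLM, GeoChatMetaForCausalLM):    # MyLLMForCausalLM is a placeholder
    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                labels=None, images=None, **kwargs):
        # GeoChat-specific step: fold the image features into the text inputs.
        (input_ids, attention_mask, past_key_values,
         inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)
        return super().forward(input_ids=input_ids, attention_mask=attention_mask,
                               past_key_values=past_key_values, inputs_embeds=inputs_embeds,
                               labels=labels, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, images=None, **kwargs):
        # Keep the images in model_inputs so they reach the model at every generation step.
        model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        model_inputs["images"] = images
        return model_inputs
```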
16
+
17
+ ## Visual Encoder
18
+
19
+ You can check out [`clip_encoder.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/model/multimodal_encoder/clip_encoder.py) on how we implement the CLIP visual encoder.
20
+
Geo/GeochatP-main/docs/Data.md ADDED
@@ -0,0 +1,24 @@
1
+ ## Finetuning Data
2
+ We use GeoChat_Instruct to finetune our model. The instruction-following annotations are in GeoChat_Instruct.json and the images are hosted in the [huggingface repo](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). The images are split into multiple archives; download all the parts into the same folder and run the following command to merge them.
3
+
4
+ ```Shell
5
+ cat images_parta* > images.zip
6
+ ```
7
+
8
+ Unzip the images into a folder and provide that folder's path in the training and evaluation scripts.
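
If you prefer to stay in Python, here is a small sketch that does the same merge-and-extract (it assumes the downloaded parts sit in the current directory):

```python
import glob
import zipfile

# Concatenate the split archives (equivalent to `cat images_parta* > images.zip`).
with open("images.zip", "wb") as merged:
    for part in sorted(glob.glob("images_parta*")):
        with open(part, "rb") as f:
            merged.write(f.read())

# Extract into a folder whose path you then pass to the training/evaluation scripts.
with zipfile.ZipFile("images.zip") as zf:
    zf.extractall("images")
```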
9
+
10
+ | Data file name | Size |
11
+ | --- | ---: |
12
+ | [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json) | 263 MB |
13
+
14
+ ## Pretraining Dataset
15
+ We use the same pretraining dataset as LLaVA-v1.5.
16
+ The pretraining dataset used in this release is a subset of CC-3M dataset, filtered with a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
17
+
18
+ If you already have CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field correspondingly if necessary.
19
+
20
+ | Data | Chat File | Meta Data | Size |
21
+ | --- | --- | --- | ---: |
22
+ | CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB
23
+ | LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB
24
+
Geo/GeochatP-main/docs/Evaluation.md ADDED
@@ -0,0 +1,54 @@
1
+ # Evaluation
2
+
3
+ We evaluate GeoChat on a variety of tasks, including scene classification, region captioning, visual grounding, grounding description and VQA.
4
+ Converted files in the input format for GeoChat are available at [GeoChat-Bench](https://huggingface.co/datasets/MBZUAI/GeoChat-Bench/tree/main)
5
+
6
+
7
+ Below we provide a general guideline for evaluating datasets.
8
+
9
+ 1. LRBEN/HRBEN.
10
+ Images and ground truth for evaluation need to be downloaded from the following sources: [LRBEN](https://zenodo.org/records/6344334), [HRBEN](https://zenodo.org/records/6344367)
11
+ Give the path to the extracted image folder in the evaluation script. We add the following text after each question during our evaluation.
12
+ ```
13
+ <question>
14
+ Answer the question using a single word or phrase.
15
+ ```
16
+ ```Shell
17
+ python geochat/eval/batch_geochat_vqa.py \
18
+ --model-path /path/to/model \
19
+ --question-file path/to/jsonl/file \
20
+ --answer-file path/to/output/jsonl/file \
21
+ --image_folder path/to/image/folder/
22
+ ```
23
+ 2. Scene Classification.
24
+ Download the images from the following sources, [UCmerced](http://weegee.vision.ucmerced.edu/datasets/landuse.html), [AID](https://drive.google.com/drive/folders/1-1D9DrYYWMGuuxx-qcvIIOV1oUkAVf-M). We add the following text after each question during our evaluation.
25
+ ```
26
+ <question>
27
+ Classify the image from the following classes. Answer in one word or a short phrase.
28
+ ```
29
+ ```Shell
30
+ python geochat/eval/batch_geochat_scene.py \
31
+ --model-path /path/to/model \
32
+ --question-file path/to/jsonl/file \
33
+ --answer-file path/to/output/jsonl/file \
34
+ --image_folder path/to/image/folder/
35
+ ```
36
+
37
+ 3. Region-Captioning/Visual grounding.
38
+
39
+ The evaluation images are present in the image.zip folder in [GeoChat_Instruct](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/images.zip).
40
+ ```Shell
41
+ python geochat/eval/batch_geochat_grounding.py \
42
+ --model-path /path/to/model \
43
+ --question-file path/to/jsonl/file \
44
+ --answer-file path/to/output/jsonl/file \
45
+ --image_folder path/to/image/folder/
46
+ ```
47
+
48
+ ```Shell
49
+ python geochat/eval/batch_geochat_referring.py \
50
+ --model-path /path/to/model \
51
+ --question-file path/to/jsonl/file \
52
+ --answer-file path/to/output/jsonl/file \
53
+ --image_folder path/to/image/folder/
54
+ ```
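
Each script writes one JSON object per line to the answers file; a minimal sketch for loading the predictions afterwards (field names follow `batch_geochat_grounding.py` later in this diff):

```python
import json

with open("answers.jsonl") as f:          # whatever was passed as --answers-file
    answers = [json.loads(line) for line in f]

print(answers[0]["question_id"], answers[0]["answer"])
```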
Geo/GeochatP-main/docs/LoRA.md ADDED
@@ -0,0 +1,24 @@
1
+
2
+ ## Demo (Web UI)
3
+ You need GeoChat-7B to run the demo locally. Download the model from [GeoChat-7B](https://huggingface.co/MBZUAI/geochat-7B). Then launch the Gradio demo by passing the model path to the following command.
4
+ #### Launch the demo
5
+ ```Shell
6
+ python geochat_demo.py --model-path /path/to/model
7
+ ```
8
+
9
+ ## Training
10
+
11
+ Please see sample training scripts for [LoRA](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh)
12
+
13
+ We provide sample DeepSpeed configs: [`zero3.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3.json) is more like PyTorch FSDP, and [`zero3_offload.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3_offload.json) can further save memory by offloading parameters to the CPU. `zero3.json` is usually faster than `zero3_offload.json` but requires more GPU memory; therefore, we recommend trying `zero3.json` first and, if you run out of GPU memory, switching to `zero3_offload.json`. You can also tweak `per_device_train_batch_size` and `gradient_accumulation_steps` in the config to save memory; just make sure their product (and hence the global batch size) remains the same.
14
+
15
+ If you are having issues with the ZeRO-3 configs and there is enough VRAM, you may try [`zero2.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero2.json). It consumes slightly more memory than ZeRO-3, behaves more like PyTorch FSDP, and still supports parameter-efficient tuning.
16
+
17
+ ## Create Merged Checkpoints
18
+
19
+ ```Shell
20
+ python scripts/merge_lora_weights.py \
21
+ --model-path /path/to/lora_model \
22
+ --model-base /path/to/base_model \
23
+ --save-model-path /path/to/merge_model
24
+ ```
Geo/GeochatP-main/docs/MODEL_ZOO.md ADDED
@@ -0,0 +1,18 @@
1
+ # Model Zoo
2
+
3
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
4
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|------------------|
5
+ | Vicuna-13B-v1.3 | CLIP-L-336px(extended to 504) | LCS-558K | 1e | Geochat_Instruct | proj-1e, lora-1e | [LoRA-Merged](https://huggingface.co/MBZUAI/geochat-7B) |
6
+
7
+ ## Projector weights
8
+ We use the projector from LLaVA-1.5 for initialization. [Link](https://huggingface.co/liuhaotian/llava-v1.5-7b-lora)
9
+
10
+ **NOTE**: When you use our pretrained projector for visual instruction tuning, it is very important to **use the same base LLM and vision encoder** as the one we used for pretraining the projector. Otherwise, the performance will be very bad.
11
+
12
+ When using these projector weights to instruction-tune your LMM, please make sure that these options are set correctly as follows:
13
+
14
+ ```Shell
15
+ --mm_use_im_start_end False
16
+ --mm_use_im_patch_token False
17
+ ```
18
+
Geo/GeochatP-main/docs/geochat_supp.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc9b5d3df4af5c06e59fc2258332e00b779c55d441405cb5a7fd7997d29b63fe
3
+ size 4839915
Geo/GeochatP-main/geochat/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import GeoChatLlamaForCausalLM
Geo/GeochatP-main/geochat/constants.py ADDED
@@ -0,0 +1,12 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
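
For context, a short sketch of how these constants are combined into a prompt, mirroring the pattern in the evaluation scripts later in this diff (`use_im_start_end` stands in for `model.config.mm_use_im_start_end`):

```python
from geochat.constants import (DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
                               DEFAULT_IM_END_TOKEN)

use_im_start_end = False                 # placeholder for model.config.mm_use_im_start_end
qs = "Classify the image."               # example question

if use_im_start_end:
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
else:
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
```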
Geo/GeochatP-main/geochat/conversation.py ADDED
@@ -0,0 +1,520 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple, Any
4
+ from PIL import Image
5
+ from threading import Thread
6
+
7
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
8
+ # from llava.conversation import conv_templates, SeparatorStyle
9
+ # from llava.model.builder import load_pretrained_model
10
+ from geochat.utils import disable_torch_init
11
+ from geochat.mm_utils import process_images_demo, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
12
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer,TextStreamer
13
+ import torch
17
+
18
+
19
+ class SeparatorStyle(Enum):
20
+ """Different separator style."""
21
+ SINGLE = auto()
22
+ TWO = auto()
23
+ MPT = auto()
24
+ PLAIN = auto()
25
+ LLAMA_2 = auto()
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class Conversation:
30
+ """A class that keeps all conversation history."""
31
+ system: str
32
+ roles: List[str]
33
+ messages: List[List[str]]
34
+ offset: int
35
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
36
+ sep: str = "###"
37
+ sep2: str = None
38
+ version: str = "Unknown"
39
+
40
+ skip_next: bool = False
41
+
42
+ def get_prompt(self):
43
+ messages = self.messages
44
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
45
+ messages = self.messages.copy()
46
+ init_role, init_msg = messages[0].copy()
47
+ init_msg = init_msg[0].replace("<image>", "").strip()
48
+ if 'mmtag' in self.version:
49
+ messages[0] = (init_role, init_msg)
50
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
51
+ messages.insert(1, (self.roles[1], "Received."))
52
+ else:
53
+ messages[0] = (init_role, "<image>\n" + init_msg)
54
+
55
+ if self.sep_style == SeparatorStyle.SINGLE:
56
+ ret = self.system + self.sep
57
+ for role, message in messages:
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + self.sep
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.TWO:
65
+ seps = [self.sep, self.sep2]
66
+ ret = self.system + seps[0]
67
+ for i, (role, message) in enumerate(messages):
68
+ if message:
69
+ if type(message) is tuple:
70
+ message, _, _ = message
71
+ ret += role + ": " + message + seps[i % 2]
72
+ else:
73
+ ret += role + ":"
74
+ elif self.sep_style == SeparatorStyle.MPT:
75
+ ret = self.system + self.sep
76
+ for role, message in messages:
77
+ if message:
78
+ if type(message) is tuple:
79
+ message, _, _ = message
80
+ ret += role + message + self.sep
81
+ else:
82
+ ret += role
83
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
84
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
85
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
86
+ ret = ""
87
+
88
+ for i, (role, message) in enumerate(messages):
89
+ if i == 0:
90
+ assert message, "first message should not be none"
91
+ assert role == self.roles[0], "first message should come from user"
92
+ if message:
93
+ if type(message) is tuple:
94
+ message, _, _ = message
95
+ if i == 0: message = wrap_sys(self.system) + message
96
+ if i % 2 == 0:
97
+ message = wrap_inst(message)
98
+ ret += self.sep + message
99
+ else:
100
+ ret += " " + message + " " + self.sep2
101
+ else:
102
+ ret += ""
103
+ ret = ret.lstrip(self.sep)
104
+ elif self.sep_style == SeparatorStyle.PLAIN:
105
+ seps = [self.sep, self.sep2]
106
+ ret = self.system
107
+ for i, (role, message) in enumerate(messages):
108
+ if message:
109
+ if type(message) is tuple:
110
+ message, _, _ = message
111
+ ret += message + seps[i % 2]
112
+ else:
113
+ ret += ""
114
+ else:
115
+ raise ValueError(f"Invalid style: {self.sep_style}")
116
+
117
+ return ret
118
+
119
+ def append_message(self, role, message):
120
+ self.messages.append([role, message])
121
+
122
+ def get_images(self, return_pil=False):
123
+ images = []
124
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
125
+ if i % 2 == 0:
126
+ if type(msg) is tuple:
127
+ import base64
128
+ from io import BytesIO
129
+ from PIL import Image
130
+ msg, image, image_process_mode = msg
131
+ if image_process_mode == "Pad":
132
+ def expand2square(pil_img, background_color=(122, 116, 104)):
133
+ width, height = pil_img.size
134
+ if width == height:
135
+ return pil_img
136
+ elif width > height:
137
+ result = Image.new(pil_img.mode, (width, width), background_color)
138
+ result.paste(pil_img, (0, (width - height) // 2))
139
+ return result
140
+ else:
141
+ result = Image.new(pil_img.mode, (height, height), background_color)
142
+ result.paste(pil_img, ((height - width) // 2, 0))
143
+ return result
144
+ image = expand2square(image)
145
+ elif image_process_mode in ["Default", "Crop"]:
146
+ pass
147
+ elif image_process_mode == "Resize":
148
+ image = image.resize((336, 336))
149
+ else:
150
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
151
+ max_hw, min_hw = max(image.size), min(image.size)
152
+ aspect_ratio = max_hw / min_hw
153
+ max_len, min_len = 800, 400
154
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
155
+ longest_edge = int(shortest_edge * aspect_ratio)
156
+ W, H = image.size
157
+ if longest_edge != max(image.size):
158
+ if H > W:
159
+ H, W = longest_edge, shortest_edge
160
+ else:
161
+ H, W = shortest_edge, longest_edge
162
+ image = image.resize((W, H))
163
+ if return_pil:
164
+ images.append(image)
165
+ else:
166
+ buffered = BytesIO()
167
+ image.save(buffered, format="PNG")
168
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
169
+ images.append(img_b64_str)
170
+ return images
171
+
172
+ def to_gradio_chatbot(self):
173
+ ret = []
174
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
175
+ if i % 2 == 0:
176
+ if type(msg) is tuple:
177
+ import base64
178
+ from io import BytesIO
179
+ msg, image, image_process_mode = msg
180
+ max_hw, min_hw = max(image.size), min(image.size)
181
+ aspect_ratio = max_hw / min_hw
182
+ max_len, min_len = 800, 400
183
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
184
+ longest_edge = int(shortest_edge * aspect_ratio)
185
+ W, H = image.size
186
+ if H > W:
187
+ H, W = longest_edge, shortest_edge
188
+ else:
189
+ H, W = shortest_edge, longest_edge
190
+ image = image.resize((W, H))
191
+ buffered = BytesIO()
192
+ image.save(buffered, format="JPEG")
193
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
194
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
195
+ msg = img_str + msg.replace('<image>', '').strip()
196
+ ret.append([msg, None])
197
+ else:
198
+ ret.append([msg, None])
199
+ else:
200
+ ret[-1][-1] = msg
201
+ return ret
202
+
203
+ def copy(self):
204
+ return Conversation(
205
+ system=self.system,
206
+ roles=self.roles,
207
+ messages=[[x, y] for x, y in self.messages],
208
+ offset=self.offset,
209
+ sep_style=self.sep_style,
210
+ sep=self.sep,
211
+ sep2=self.sep2,
212
+ version=self.version)
213
+
214
+ def dict(self):
215
+ if len(self.get_images()) > 0:
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ }
224
+ return {
225
+ "system": self.system,
226
+ "roles": self.roles,
227
+ "messages": self.messages,
228
+ "offset": self.offset,
229
+ "sep": self.sep,
230
+ "sep2": self.sep2,
231
+ }
232
+
233
+
234
+ conv_vicuna_v0 = Conversation(
235
+ system="A chat between a curious human and an artificial intelligence assistant. "
236
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
237
+ roles=("Human", "Assistant"),
238
+ messages=(
239
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
240
+ ("Assistant",
241
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
242
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
243
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
244
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
245
+ "renewable and non-renewable energy sources:\n"
246
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
247
+ "energy sources are finite and will eventually run out.\n"
248
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
249
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
250
+ "and other negative effects.\n"
251
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
252
+ "have lower operational costs than non-renewable sources.\n"
253
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
254
+ "locations than non-renewable sources.\n"
255
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
256
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
257
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
258
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
259
+ ),
260
+ offset=2,
261
+ sep_style=SeparatorStyle.SINGLE,
262
+ sep="###",
263
+ )
264
+
265
+ conv_vicuna_v1 = Conversation(
266
+ system="A chat between a curious user and an artificial intelligence assistant. "
267
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="v1",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.TWO,
273
+ sep=" ",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_llama_2 = Conversation(
278
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
279
+
280
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
281
+ roles=("USER", "ASSISTANT"),
282
+ version="llama_v2",
283
+ messages=(),
284
+ offset=0,
285
+ sep_style=SeparatorStyle.LLAMA_2,
286
+ sep="<s>",
287
+ sep2="</s>",
288
+ )
289
+
290
+ conv_llava_llama_2 = Conversation(
291
+ system="You are a helpful language and vision assistant. "
292
+ "You are able to understand the visual content that the user provides, "
293
+ "and assist the user with a variety of tasks using natural language.",
294
+ roles=("USER", "ASSISTANT"),
295
+ version="llama_v2",
296
+ messages=(),
297
+ offset=0,
298
+ sep_style=SeparatorStyle.LLAMA_2,
299
+ sep="<s>",
300
+ sep2="</s>",
301
+ )
302
+
303
+ conv_mpt = Conversation(
304
+ system="""<|im_start|>system
305
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
306
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
307
+ version="mpt",
308
+ messages=(),
309
+ offset=0,
310
+ sep_style=SeparatorStyle.MPT,
311
+ sep="<|im_end|>",
312
+ )
313
+
314
+ conv_llava_plain = Conversation(
315
+ system="",
316
+ roles=("", ""),
317
+ messages=(
318
+ ),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.PLAIN,
321
+ sep="\n",
322
+ )
323
+
324
+ conv_llava_v0 = Conversation(
325
+ system="A chat between a curious human and an artificial intelligence assistant. "
326
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
327
+ roles=("Human", "Assistant"),
328
+ messages=(
329
+ ),
330
+ offset=0,
331
+ sep_style=SeparatorStyle.SINGLE,
332
+ sep="###",
333
+ )
334
+
335
+ conv_llava_v0_mmtag = Conversation(
336
+ system="A chat between a curious user and an artificial intelligence assistant. "
337
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
338
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
339
+ roles=("Human", "Assistant"),
340
+ messages=(
341
+ ),
342
+ offset=0,
343
+ sep_style=SeparatorStyle.SINGLE,
344
+ sep="###",
345
+ version="v0_mmtag",
346
+ )
347
+
348
+ conv_llava_v1 = Conversation(
349
+ system="A chat between a curious human and an artificial intelligence assistant. "
350
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
351
+ roles=("USER", "ASSISTANT"),
352
+ version="v1",
353
+ messages=(),
354
+ offset=0,
355
+ sep_style=SeparatorStyle.TWO,
356
+ sep=" ",
357
+ sep2="</s>",
358
+ )
359
+
360
+ conv_llava_v1_mmtag = Conversation(
361
+ system="A chat between a curious user and an artificial intelligence assistant. "
362
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
363
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
364
+ roles=("USER", "ASSISTANT"),
365
+ messages=(),
366
+ offset=0,
367
+ sep_style=SeparatorStyle.TWO,
368
+ sep=" ",
369
+ sep2="</s>",
370
+ version="v1_mmtag",
371
+ )
372
+
373
+ default_conversation = conv_vicuna_v0
374
+ conv_templates = {
375
+ "default": conv_vicuna_v0,
376
+ "v0": conv_vicuna_v0,
377
+ "v1": conv_vicuna_v1,
378
+ "vicuna_v1": conv_vicuna_v1,
379
+ "llama_2": conv_llama_2,
380
+
381
+ "plain": conv_llava_plain,
382
+ "v0_plain": conv_llava_plain,
383
+ "llava_v0": conv_llava_v0,
384
+ "v0_mmtag": conv_llava_v0_mmtag,
385
+ "llava_v1": conv_llava_v1,
386
+ "v1_mmtag": conv_llava_v1_mmtag,
387
+ "llava_llama_2": conv_llava_llama_2,
388
+
389
+ "mpt": conv_mpt,
390
+ }
391
+
392
+ class Chat:
393
+ def __init__(self, model, image_processor,tokenizer, device='cuda:0', stopping_criteria=None):
394
+ self.device = device
395
+ self.model = model
396
+ self.vis_processor = image_processor
397
+ self.tokenizer=tokenizer
398
+
399
+ # if stopping_criteria is not None:
400
+ # self.stopping_criteria = stopping_criteria
401
+ # else:
402
+ # stop_words_ids = [torch.tensor([2]).to(self.device)]
403
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
404
+
405
+ def ask(self, text, conv):
406
+ # import pdb;pdb.set_trace()
407
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
408
+ and conv.messages[-1][1][-9:] == '<image>\n': # last message is image.
409
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
410
+ else:
411
+ conv.append_message(conv.roles[0], text)
412
+
413
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
414
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
415
+ conv.append_message(conv.roles[1], None)
416
+ prompt = conv.get_prompt()
417
+ # prompt='A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: <image>\n hello ASSISTANT:'
418
+ text_input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device=self.device)
419
+
420
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
421
+ keywords = [stop_str]
422
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, text_input_ids)
423
+ current_max_len = text_input_ids.shape[1] + max_new_tokens
424
+ if current_max_len - max_length > 0:
425
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
426
+ 'The model will not see the contexts outside the range.')
427
+ begin_idx = max(0, current_max_len - max_length)
428
+ embs = text_input_ids[:, begin_idx:]
429
+
430
+ generation_kwargs = dict(
431
+ input_ids=embs,
432
+ images=img_list[0],
433
+ max_new_tokens=max_new_tokens,
434
+ stopping_criteria=[stopping_criteria],
435
+ num_beams=num_beams,
436
+ do_sample=True,
437
+ min_length=min_length,
438
+ top_p=top_p,
439
+ use_cache=True,
440
+ repetition_penalty=repetition_penalty,
441
+ length_penalty=length_penalty,
442
+ temperature=float(temperature),
443
+ )
444
+ return generation_kwargs
445
+
446
+ # def answer(self, conv, img_list, **kargs):
447
+ # generation_dict = self.answer_prepare(conv, img_list, **kargs)
448
+ # output_token = self.model_generate(**generation_dict)[0]
449
+ # output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
450
+
451
+ # output_text = output_text.split('###')[0] # remove the stop sign '###'
452
+ # output_text = output_text.split('Assistant:')[-1].strip()
453
+
454
+ # conv.messages[-1][1] = output_text
455
+ # return output_text, output_token.cpu().numpy()
456
+
457
+ def stream_answer(self, conv, img_list, **kargs):
458
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
459
+
460
+ streamer = TextIteratorStreamer(self.tokenizer,skip_prompt=True, skip_special_tokens=True)
461
+ generation_kwargs['streamer'] = streamer
462
+ # import pdb;pdb.set_trace()
463
+ # output_ids=self.model.generate(*generation_kwargs)
464
+ output=self.model_generate(kwargs=generation_kwargs)
465
+ # thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
466
+ # thread.start()
467
+ return streamer
468
+
469
+ def model_generate(self, *args, **kwargs):
470
+ # for 8 bit and 16 bit compatibility
471
+ with torch.inference_mode():
472
+ output = self.model.generate(kwargs['kwargs']['input_ids'],
473
+ images=kwargs['kwargs']['images'],
474
+ do_sample=False,
475
+ temperature=kwargs['kwargs']['temperature'],
476
+ max_new_tokens=kwargs['kwargs']['max_new_tokens'],
477
+ streamer=kwargs['kwargs']['streamer'],
478
+ use_cache=kwargs['kwargs']['use_cache'],
479
+ stopping_criteria=kwargs['kwargs']['stopping_criteria'])
480
+ # import pdb;pdb.set_trace()
481
+ # print(output)
482
+ outputs = self.tokenizer.decode(output[0,kwargs['kwargs']['input_ids'].shape[1]:]).strip()
483
+ # print(outputs)
484
+ return output
485
+
486
+ def encode_img(self, img_list):
487
+
488
+ image = img_list[0]
489
+ # image='/share/data/drive_3/kartik/LLaVA/output_images/output.jpg'
490
+ img_list.pop(0)
491
+ if isinstance(image, str): # is a image path
492
+ raw_image = Image.open(image).convert('RGB')
493
+ image = process_images_demo([raw_image], self.vis_processor)
494
+ # print("raw")
495
+ # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
496
+ elif isinstance(image, Image.Image):
497
+ raw_image = image
498
+ image = process_images_demo([raw_image], self.vis_processor )
499
+ image=image.to(device=self.device,dtype=torch.float16)
500
+ # print("Image")
501
+ # image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
502
+ elif isinstance(image, torch.Tensor):
503
+ if len(image.shape) == 3:
504
+ image = image.unsqueeze(0)
505
+ image = image.to(self.device)
506
+
507
+ # image_emb, _ = self.model.encode_img(image)
508
+ img_list.append(image)
509
+
510
+ def upload_img(self, image, conv, img_list):
511
+ conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN+'\n')
512
+ img_list.append(image)
513
+ msg = "Received."
514
+
515
+ return msg
516
+
517
+
518
+
519
+ # if __name__ == "__main__":
520
+ # print(default_conversation.get_prompt())
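
A minimal usage sketch for the templates above, mirroring how the evaluation scripts build their prompts (the question text is an arbitrary example):

```python
from geochat.conversation import conv_templates, SeparatorStyle

conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe the scene.")   # user turn
conv.append_message(conv.roles[1], None)                             # empty assistant slot
prompt = conv.get_prompt()

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
print(prompt, stop_str)
```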
Geo/GeochatP-main/geochat/eval/batch_geochat_grounding.py ADDED
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n) # ceiling division so no items are dropped
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
+ # import pdb;pdb.set_trace()
34
+ # print(model)
35
+ questions=[]
36
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37
+
38
+
39
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
40
+ answers_file = os.path.expanduser(args.answers_file)
41
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
42
+
43
+ ans_file = open(answers_file, "w")
44
+
45
+ for i in tqdm(range(0,len(questions),args.batch_size)):
46
+ input_batch=[]
47
+ input_image_batch=[]
48
+ count=i
49
+ image_folder=[]
50
+ batch_end = min(i + args.batch_size, len(questions))
51
+
52
+
53
+ for j in range(i,batch_end):
54
+ image_file=questions[j]['image_id']+'.png'
55
+
56
+ if questions[j]['type']=='ref':
57
+ qs="[refer] Give me the location of <p> " + questions[j]['question'] + " </p>"
58
+ else:
59
+ qs="[grounding]" + questions[j]['question']
60
+
61
+ if model.config.mm_use_im_start_end:
62
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
63
+ else:
64
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
65
+
66
+ conv = conv_templates[args.conv_mode].copy()
67
+ conv.append_message(conv.roles[0], qs)
68
+ conv.append_message(conv.roles[1], None)
69
+ prompt = conv.get_prompt()
70
+
71
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
72
+ input_batch.append(input_ids)
73
+
74
+ image = Image.open(os.path.join(args.image_folder, image_file))
75
+
76
+ image_folder.append(image)
77
+
78
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
79
+ keywords = [stop_str]
80
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
81
+
82
+ max_length = max(tensor.size(1) for tensor in input_batch)
83
+
84
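+ # Left-pad each prompt with zeros up to the longest prompt in the batch so that
+ # the per-sample input_ids can be stacked into one tensor for batched generation.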
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
85
+ final_input_tensors=torch.cat(final_input_list,dim=0)
86
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
87
+
88
+ with torch.inference_mode():
89
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
90
+
91
+ input_token_len = final_input_tensors.shape[1]
92
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
93
+ if n_diff_input_output > 0:
94
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
95
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
96
+ for k in range(0,len(final_input_list)):
97
+ output = outputs[k].strip()
98
+ if output.endswith(stop_str):
99
+ output = output[:-len(stop_str)]
100
+ output = output.strip()
101
+
102
+ ans_id = shortuuid.uuid()
103
+
104
+ ans_file.write(json.dumps({
105
+
106
+ "question_id": questions[count]["question_id"],
107
+ "image_id": questions[count]["image_id"],
108
+ "answer": output,
109
+ "ground_truth": questions[count]['ground_truth'],
110
+ "question":questions[count]['question'],
111
+ "type": questions[count]['type'],
112
+ "dataset": questions[count]['dataset'],
113
+ "obj_ids": questions[count]['obj_ids'],
114
+ "size_group": questions[count]['size_group'],
115
+
116
+ }) + "\n")
117
+ count=count+1
118
+ ans_file.flush()
119
+ ans_file.close()
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
125
+ parser.add_argument("--model-base", type=str, default=None)
126
+ parser.add_argument("--image-folder", type=str, default="")
127
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
128
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
129
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
130
+ parser.add_argument("--num-chunks", type=int, default=1)
131
+ parser.add_argument("--chunk-idx", type=int, default=0)
132
+ parser.add_argument("--temperature", type=float, default=0.2)
133
+ parser.add_argument("--top_p", type=float, default=None)
134
+ parser.add_argument("--num_beams", type=int, default=1)
135
+ parser.add_argument("--batch_size",type=int, default=1)
136
+ args = parser.parse_args()
137
+
138
+ eval_model(args)
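+
+ # Example invocation (all paths are placeholders; the checkpoint directory name
+ # should contain "geochat" so load_pretrained_model selects the GeoChat loader):
+ #   python geochat/eval/batch_geochat_grounding.py \
+ #       --model-path /path/to/geochat-7B \
+ #       --question-file /path/to/grounding_questions.jsonl \
+ #       --image-folder /path/to/images \
+ #       --answers-file /path/to/answers/grounding.jsonl \
+ #       --batch_size 8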
Geo/GeochatP-main/geochat/eval/batch_geochat_referring.py ADDED
@@ -0,0 +1,132 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n) # integer division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
+ # print(model)
34
+ questions=[]
35
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
36
+
37
+
38
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
+ answers_file = os.path.expanduser(args.answers_file)
40
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
+
42
+ ans_file = open(answers_file, "w")
43
+
44
+ for i in tqdm(range(0,len(questions),args.batch_size)):
45
+ input_batch=[]
46
+ input_image_batch=[]
47
+ count=i
48
+ image_folder=[]
49
+ batch_end = min(i + args.batch_size, len(questions))
50
+
51
+
52
+ for j in range(i,batch_end):
53
+ image_file=questions[j]['image_id']+'.png'
54
+ qs="[identify] What is the object present at " + questions[j]['question']
55
+
56
+ if model.config.mm_use_im_start_end:
57
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
58
+ else:
59
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
60
+
61
+ conv = conv_templates[args.conv_mode].copy()
62
+ conv.append_message(conv.roles[0], qs)
63
+ conv.append_message(conv.roles[1], None)
64
+ prompt = conv.get_prompt()
65
+
66
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
67
+ input_batch.append(input_ids)
68
+
69
+ image = Image.open(os.path.join(args.image_folder, image_file))
70
+
71
+ image_folder.append(image)
72
+
73
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
74
+ keywords = [stop_str]
75
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
76
+
77
+ max_length = max(tensor.size(1) for tensor in input_batch)
78
+
79
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
80
+ final_input_tensors=torch.cat(final_input_list,dim=0)
81
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
82
+
83
+ with torch.inference_mode():
84
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
85
+
86
+ input_token_len = final_input_tensors.shape[1]
87
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
88
+ if n_diff_input_output > 0:
89
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
91
+ for k in range(0,len(final_input_list)):
92
+ output = outputs[k].strip()
93
+ if output.endswith(stop_str):
94
+ output = output[:-len(stop_str)]
95
+ output = output.strip()
96
+
97
+ ans_id = shortuuid.uuid()
98
+
99
+ ans_file.write(json.dumps({
100
+ "question_id": questions[count]["question_id"],
101
+ "image_id": questions[count]["image_id"],
102
+ "answer": output,
103
+ "ground_truth": questions[count]['ground_truth'],
104
+ "question":questions[count]['question'],
105
+ "type": questions[count]['type'],
106
+ "dataset": questions[count]['dataset'],
107
+ "obj_ids": questions[count]['obj_ids'],
108
+ "size_group": questions[count]['size_group'],
109
+
110
+ }) + "\n")
111
+ count=count+1
112
+ ans_file.flush()
113
+ ans_file.close()
114
+
115
+
116
+ if __name__ == "__main__":
117
+ parser = argparse.ArgumentParser()
118
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
119
+ parser.add_argument("--model-base", type=str, default=None)
120
+ parser.add_argument("--image-folder", type=str, default="")
121
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
122
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
123
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
124
+ parser.add_argument("--num-chunks", type=int, default=1)
125
+ parser.add_argument("--chunk-idx", type=int, default=0)
126
+ parser.add_argument("--temperature", type=float, default=0.2)
127
+ parser.add_argument("--top_p", type=float, default=None)
128
+ parser.add_argument("--num_beams", type=int, default=1)
129
+ parser.add_argument("--batch_size",type=int, default=1)
130
+ args = parser.parse_args()
131
+
132
+ eval_model(args)
Geo/GeochatP-main/geochat/eval/batch_geochat_scene.py ADDED
@@ -0,0 +1,139 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+
+ # split_list/get_chunk mirror the other eval scripts; get_chunk() is called below.
+ def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
17
+ def evaluation_metrics(data_path):
18
+
19
+ base = [json.loads(q) for q in open(data_path, "r")]
20
+ correct=0
21
+ incorrect=0
22
+ for answers in tqdm(base):
23
+ gt=answers['question_id'].split('/')[0].lower()
24
+ answer=answers['answer'].replace(' ','').lower().replace('.','')
25
+ if gt==answer:
26
+ correct=correct+1
27
+ else:
28
+ incorrect=incorrect+1
29
+ # else:
30
+ # continue
31
+ print('correct:',correct)
32
+ print('incorrect:',incorrect)
33
+ print('Total:',correct+incorrect)
34
+ print('Acc:',(correct/(correct+incorrect)))
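+ # Note: accuracy is exact string match between the class name embedded in
+ # question_id (the part before '/') and the model's answer with spaces and
+ # periods stripped.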
35
+
36
+
37
+
38
+
39
+ def eval_model(args):
40
+ # Model
41
+ disable_torch_init()
42
+ model_path = os.path.expanduser(args.model_path)
43
+ model_name = get_model_name_from_path(model_path)
44
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
45
+ # print(model)
46
+ questions=[]
47
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
48
+
49
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
50
+ answers_file = os.path.expanduser(args.answers_file)
51
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
52
+
53
+ ans_file = open(answers_file, "w")
54
+
55
+ for i in tqdm(range(0,len(questions),args.batch_size)):
56
+ input_batch=[]
57
+ input_image_batch=[]
58
+ count=i
59
+ image_folder=[]
60
+ batch_end = min(i + args.batch_size, len(questions))
61
+
62
+
63
+ for j in range(i,batch_end):
64
+ image_file=questions[j]['image']
65
+ qs=questions[j]['text']
66
+
67
+ if model.config.mm_use_im_start_end:
68
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
69
+ else:
70
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
71
+
72
+ conv = conv_templates[args.conv_mode].copy()
73
+ conv.append_message(conv.roles[0], qs)
74
+ conv.append_message(conv.roles[1], None)
75
+ prompt = conv.get_prompt()
76
+
77
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
78
+ input_batch.append(input_ids)
79
+
80
+ image = Image.open(os.path.join(args.image_folder, image_file))
81
+
82
+ image_folder.append(image)
83
+
84
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
85
+ keywords = [stop_str]
86
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
87
+
88
+ max_length = max(tensor.size(1) for tensor in input_batch)
89
+
90
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
91
+ final_input_tensors=torch.cat(final_input_list,dim=0)
92
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
93
+
94
+ with torch.inference_mode():
95
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
96
+
97
+ input_token_len = final_input_tensors.shape[1]
98
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
99
+ if n_diff_input_output > 0:
100
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
101
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
102
+ for k in range(0,len(final_input_list)):
103
+ output = outputs[k].strip()
104
+ if output.endswith(stop_str):
105
+ output = output[:-len(stop_str)]
106
+ output = output.strip()
107
+
108
+ ans_id = shortuuid.uuid()
109
+
110
+ ans_file.write(json.dumps({
111
+
112
+ "question_id": questions[count]["question_id"],
113
+ "image_id": questions[count]["image"],
114
+ "answer": output,
115
+ "ground_truth": questions[count]['ground_truth']
116
+ }) + "\n")
117
+ count=count+1
118
+ ans_file.flush()
119
+ ans_file.close()
120
+ evaluation_metrics(answers_file)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ parser = argparse.ArgumentParser()
125
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
126
+ parser.add_argument("--model-base", type=str, default=None)
127
+ parser.add_argument("--image-folder", type=str, default="")
128
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
129
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
130
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
131
+ parser.add_argument("--num-chunks", type=int, default=1)
132
+ parser.add_argument("--chunk-idx", type=int, default=0)
133
+ parser.add_argument("--temperature", type=float, default=0.2)
134
+ parser.add_argument("--top_p", type=float, default=None)
135
+ parser.add_argument("--num_beams", type=int, default=1)
136
+ parser.add_argument("--batch_size",type=int, default=1)
137
+ args = parser.parse_args()
138
+
139
+ eval_model(args)
Geo/GeochatP-main/geochat/eval/batch_geochat_vqa.py ADDED
@@ -0,0 +1,125 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from geochat.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from geochat.conversation import conv_templates, SeparatorStyle
10
+ from geochat.model.builder import load_pretrained_model
11
+ from geochat.utils import disable_torch_init
12
+ from geochat.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from PIL import Image
15
+ import math
16
+ def split_list(lst, n):
17
+ """Split a list into n (roughly) equal-sized chunks"""
18
+ chunk_size = math.ceil(len(lst) / n) # ceiling division
19
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
20
+
21
+
22
+ def get_chunk(lst, n, k):
23
+ chunks = split_list(lst, n)
24
+ return chunks[k]
25
+
26
+
27
+ def eval_model(args):
28
+ # Model
29
+ disable_torch_init()
30
+ model_path = os.path.expanduser(args.model_path)
31
+ model_name = get_model_name_from_path(model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
33
+
34
+ questions=[]
35
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
36
+
37
+
38
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
+ answers_file = os.path.expanduser(args.answers_file)
40
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
+
42
+ ans_file = open(answers_file, "w")
43
+
44
+ for i in tqdm(range(0,len(questions),args.batch_size)):
45
+ input_batch=[]
46
+ input_image_batch=[]
47
+ count=i
48
+ image_folder=[]
49
+ batch_end = min(i + args.batch_size, len(questions))
50
+
51
+
52
+ for j in range(i,batch_end):
53
+ image_file=questions[j]['image']
54
+ qs=questions[j]['text']
55
+
56
+ if model.config.mm_use_im_start_end:
57
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
58
+ else:
59
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
60
+
61
+ conv = conv_templates[args.conv_mode].copy()
62
+ conv.append_message(conv.roles[0], qs)
63
+ conv.append_message(conv.roles[1], None)
64
+ prompt = conv.get_prompt()
65
+
66
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
67
+ input_batch.append(input_ids)
68
+
69
+ image = Image.open(os.path.join(args.image_folder, image_file))
70
+
71
+ image_folder.append(image)
72
+
73
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
74
+ keywords = [stop_str]
75
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
76
+
77
+ max_length = max(tensor.size(1) for tensor in input_batch)
78
+
79
+ final_input_list = [torch.cat((torch.zeros((1,max_length - tensor.size(1)), dtype=tensor.dtype,device=tensor.get_device()), tensor),dim=1) for tensor in input_batch]
80
+ final_input_tensors=torch.cat(final_input_list,dim=0)
81
+ image_tensor_batch = image_processor.preprocess(image_folder,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504}, return_tensors='pt')['pixel_values']
82
+
83
+ with torch.inference_mode():
84
+ output_ids = model.generate( final_input_tensors, images=image_tensor_batch.half().cuda(), do_sample=False , temperature=args.temperature, top_p=args.top_p, num_beams=1, max_new_tokens=256,length_penalty=2.0, use_cache=True)
85
+
86
+ input_token_len = final_input_tensors.shape[1]
87
+ n_diff_input_output = (final_input_tensors != output_ids[:, :input_token_len]).sum().item()
88
+ if n_diff_input_output > 0:
89
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
91
+ for k in range(0,len(final_input_list)):
92
+ output = outputs[k].strip()
93
+ if output.endswith(stop_str):
94
+ output = output[:-len(stop_str)]
95
+ output = output.strip()
96
+
97
+ ans_id = shortuuid.uuid()
98
+
99
+ ans_file.write(json.dumps({
100
+ "question_id": questions[count]["question_id"],
101
+ "image_id": questions[count]["image"],
102
+ "answer": output,
103
+ }) + "\n")
104
+ count=count+1
105
+ ans_file.flush()
106
+ ans_file.close()
107
+
108
+
109
+ if __name__ == "__main__":
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
112
+ parser.add_argument("--model-base", type=str, default=None)
113
+ parser.add_argument("--image-folder", type=str, default="")
114
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
115
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
116
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
117
+ parser.add_argument("--num-chunks", type=int, default=1)
118
+ parser.add_argument("--chunk-idx", type=int, default=0)
119
+ parser.add_argument("--temperature", type=float, default=0.2)
120
+ parser.add_argument("--top_p", type=float, default=None)
121
+ parser.add_argument("--num_beams", type=int, default=1)
122
+ parser.add_argument("--batch_size",type=int, default=1)
123
+ args = parser.parse_args()
124
+
125
+ eval_model(args)
Geo/GeochatP-main/geochat/mm_utils.py ADDED
@@ -0,0 +1,121 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from geochat.constants import IMAGE_TOKEN_INDEX
8
+ import numpy as np
9
+
10
+ def load_image_from_base64(image):
11
+ return Image.open(BytesIO(base64.b64decode(image)))
12
+
13
+
14
+ def expand2square(pil_img, background_color):
15
+ width, height = pil_img.size
16
+ if width == height:
17
+ return pil_img
18
+ elif width > height:
19
+ result = Image.new(pil_img.mode, (width, width), background_color)
20
+ result.paste(pil_img, (0, (width - height) // 2))
21
+ return result
22
+ else:
23
+ result = Image.new(pil_img.mode, (height, height), background_color)
24
+ result.paste(pil_img, ((height - width) // 2, 0))
25
+ return result
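+
+ # expand2square pads the shorter side with the processor's mean color so the
+ # image is not distorted before the preprocess calls below resize/crop it to the
+ # 504x504 resolution hard-coded for GeoChat.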
26
+
27
+
28
+ def process_images(images, image_processor, model_cfg):
29
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30
+ new_images = []
31
+ if image_aspect_ratio == 'pad':
32
+ for image in images:
33
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34
+ image = image_processor.preprocess(image,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504},return_tensors='pt')['pixel_values'][0]
35
+ # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
36
+
37
+ new_images.append(image)
38
+ else:
39
+ return image_processor(images, return_tensors='pt')['pixel_values']
40
+ if all(x.shape == new_images[0].shape for x in new_images):
41
+ new_images = torch.stack(new_images, dim=0)
42
+ return new_images
43
+
44
+ def process_images_demo(images, image_processor):
45
+ new_images = []
46
+ # image_aspect_ratio = 'pad'
47
+ for image in images:
48
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
49
+ image = image_processor.preprocess(image,crop_size ={'height': 504, 'width': 504},size = {'shortest_edge': 504},return_tensors='pt')['pixel_values'][0]
50
+ # image = image_processor.preprocess(image,return_tensors='pt')['pixel_values'][0]
51
+
52
+ new_images.append(image)
53
+
54
+ if all(x.shape == new_images[0].shape for x in new_images):
55
+ new_images = torch.stack(new_images, dim=0)
56
+ return new_images
57
+
58
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
59
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
60
+
61
+ def insert_separator(X, sep):
62
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63
+
64
+ input_ids = []
65
+ offset = 0
66
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
67
+ offset = 1
68
+ input_ids.append(prompt_chunks[0][0])
69
+
70
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
71
+ input_ids.extend(x[offset:])
72
+
73
+ if return_tensors is not None:
74
+ if return_tensors == 'pt':
75
+ return torch.tensor(input_ids, dtype=torch.long)
76
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
77
+ return input_ids
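+
+ # Example (illustrative): for a prompt like "USER: <image>\nDescribe the image",
+ # this returns the usual tokenizer ids with a single IMAGE_TOKEN_INDEX sentinel
+ # spliced in where the literal '<image>' placeholder appeared.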
78
+
79
+
80
+ def get_model_name_from_path(model_path):
81
+ model_path = model_path.strip("/")
82
+ model_paths = model_path.split("/")
83
+ if model_paths[-1].startswith('checkpoint-'):
84
+ return model_paths[-2] + "_" + model_paths[-1]
85
+ else:
86
+ return model_paths[-1]
87
+
88
+
89
+
90
+
91
+ class KeywordsStoppingCriteria(StoppingCriteria):
92
+ def __init__(self, keywords, tokenizer, input_ids):
93
+ self.keywords = keywords
94
+ self.keyword_ids = []
95
+ self.max_keyword_len = 0
96
+ for keyword in keywords:
97
+ cur_keyword_ids = tokenizer(keyword).input_ids
98
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
99
+ cur_keyword_ids = cur_keyword_ids[1:]
100
+ if len(cur_keyword_ids) > self.max_keyword_len:
101
+ self.max_keyword_len = len(cur_keyword_ids)
102
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
103
+ self.tokenizer = tokenizer
104
+ self.start_len = input_ids.shape[1]
105
+
106
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
107
+ # assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
108
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
109
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
110
+ for keyword_id in self.keyword_ids:
111
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
112
+ return True
113
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)
114
+ flag=False
115
+ for output in outputs:
116
+
117
+ for keyword in self.keywords:
118
+ if keyword in output:
119
+ flag=True
120
+ return flag
121
+ return flag
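+
+ # The stopping check is two-stage: first an exact suffix match on keyword token
+ # ids, then a fallback that decodes the newly generated tokens and looks for the
+ # keyword as a substring of the decoded text.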
Geo/GeochatP-main/geochat/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .language_model.geochat_llama import GeoChatLlamaForCausalLM, GeoChatConfig
2
+ from .language_model.geochat_mpt import GeoChatMPTForCausalLM, GeoChatMPTConfig
Geo/GeochatP-main/geochat/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m geochat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from geochat import GeoChatLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = GeoChatLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
Geo/GeochatP-main/geochat/model/builder.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from geochat.model import *
23
+ from geochat.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ if load_8bit:
30
+ kwargs['load_in_8bit'] = True
31
+ elif load_4bit:
32
+ kwargs['load_in_4bit'] = True
33
+ kwargs['quantization_config'] = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type='nf4'
38
+ )
39
+ else:
40
+ kwargs['torch_dtype'] = torch.float16
41
+
42
+ if 'geochat' in model_name.lower():
43
+ # Load GeoChat model
44
+ if 'lora' in model_name.lower() and model_base is None:
45
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
46
+ if 'lora' in model_name.lower() and model_base is not None:
47
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
48
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
49
+ print('Loading Geochat from base model...')
50
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
51
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
52
+ if model.lm_head.weight.shape[0] != token_num:
53
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
54
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
55
+
56
+ print('Loading additional GeoChat weights...')
57
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
58
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
59
+ else:
60
+ # this is probably from HF Hub
61
+ from huggingface_hub import hf_hub_download
62
+ def load_from_hf(repo_id, filename, subfolder=None):
63
+ cache_file = hf_hub_download(
64
+ repo_id=repo_id,
65
+ filename=filename,
66
+ subfolder=subfolder)
67
+ return torch.load(cache_file, map_location='cpu')
68
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
69
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
70
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
71
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
72
+ model.load_state_dict(non_lora_trainables, strict=False)
73
+
74
+ from peft import PeftModel
75
+ print('Loading LoRA weights...')
76
+ model = PeftModel.from_pretrained(model, model_path)
77
+ print('Merging LoRA weights...')
78
+ model = model.merge_and_unload()
79
+ print('Model is loaded...')
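+ # LoRA loading order: base weights first, then the non-LoRA trainables
+ # (mm_projector / resized embeddings), then the LoRA adapters, which are
+ # finally merged back into the base model.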
80
+ elif model_base is not None:
81
+ # this may be mm projector only
82
+ print('Loading GeoChat from base model...')
83
+ if 'mpt' in model_name.lower():
84
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
85
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
87
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
88
+ model = GeoChatMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
89
+ else:
90
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
92
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93
+
94
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96
+ model.load_state_dict(mm_projector_weights, strict=False)
97
+ else:
98
+ if 'mpt' in model_name.lower():
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
100
+ model = GeoChatMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
101
+ else:
102
+ print("Loading GeoChat......")
103
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
104
+ model = GeoChatLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
105
+ else:
106
+ # Load language model
107
+ if model_base is not None:
108
+ # PEFT model
109
+ from peft import PeftModel
110
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
111
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
112
+ print(f"Loading LoRA weights from {model_path}")
113
+ model = PeftModel.from_pretrained(model, model_path)
114
+ print(f"Merging weights")
115
+ model = model.merge_and_unload()
116
+ print('Convert to FP16...')
117
+ model.to(torch.float16)
118
+ else:
119
+ use_fast = False
120
+ if 'mpt' in model_name.lower():
121
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
122
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
123
+ else:
124
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
125
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
126
+
127
+ image_processor = None
128
+
129
+ if 'geochat' in model_name.lower():
130
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
131
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
132
+ if mm_use_im_patch_token:
133
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
134
+ if mm_use_im_start_end:
135
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
136
+ model.resize_token_embeddings(len(tokenizer))
137
+
138
+ vision_tower = model.get_vision_tower()
139
+ if not vision_tower.is_loaded:
140
+ vision_tower.load_model()
141
+ vision_tower.to(device=device, dtype=torch.float16)
142
+ image_processor = vision_tower.image_processor
143
+
144
+ if hasattr(model.config, "max_sequence_length"):
145
+ context_len = model.config.max_sequence_length
146
+ else:
147
+ context_len = 2048
148
+
149
+ return tokenizer, model, image_processor, context_len
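+
+ # Minimal usage sketch (the checkpoint path is a placeholder; its name must
+ # contain "geochat" for the vision tower branch above to be taken):
+ #   from geochat.model.builder import load_pretrained_model
+ #   from geochat.mm_utils import get_model_name_from_path
+ #   path = "/path/to/geochat-7B"
+ #   tokenizer, model, image_processor, context_len = load_pretrained_model(
+ #       path, None, get_model_name_from_path(path))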
Geo/GeochatP-main/geochat/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Usage:
3
+ python3 -m geochat.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from geochat.model import *
10
+ from geochat.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
Geo/GeochatP-main/geochat/model/geochat_arch.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from geochat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+
27
+ class GeoChatMetaModel:
28
+
29
+ def __init__(self, config):
30
+ super(GeoChatMetaModel, self).__init__(config)
31
+
32
+ if hasattr(config, "mm_vision_tower"):
33
+ self.vision_tower = build_vision_tower(config, delay_load=True)
34
+ self.mm_projector = build_vision_projector(config)
35
+
36
+ def get_vision_tower(self):
37
+ vision_tower = getattr(self, 'vision_tower', None)
38
+ if type(vision_tower) is list:
39
+ vision_tower = vision_tower[0]
40
+ return vision_tower
41
+
42
+ def initialize_vision_modules(self, model_args, fsdp=None):
43
+ vision_tower = model_args.vision_tower
44
+ mm_vision_select_layer = model_args.mm_vision_select_layer
45
+ mm_vision_select_feature = model_args.mm_vision_select_feature
46
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
47
+
48
+ self.config.mm_vision_tower = vision_tower
49
+
50
+ if self.get_vision_tower() is None:
51
+ vision_tower = build_vision_tower(model_args)
52
+
53
+ if fsdp is not None and len(fsdp) > 0:
54
+ self.vision_tower = [vision_tower]
55
+ else:
56
+ self.vision_tower = vision_tower
57
+ else:
58
+ if fsdp is not None and len(fsdp) > 0:
59
+ vision_tower = self.vision_tower[0]
60
+ else:
61
+ vision_tower = self.vision_tower
62
+ vision_tower.load_model()
63
+
64
+ self.config.use_mm_proj = True
65
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66
+ self.config.mm_hidden_size = vision_tower.hidden_size
67
+ self.config.mm_vision_select_layer = mm_vision_select_layer
68
+ self.config.mm_vision_select_feature = mm_vision_select_feature
69
+
70
+ if getattr(self, 'mm_projector', None) is None:
71
+ self.mm_projector = build_vision_projector(self.config)
72
+ # print(mm_projector)
73
+
74
+
75
+ if pretrain_mm_mlp_adapter is not None:
76
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
77
+
78
+ def get_w(weights, keyword):
79
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
80
+
81
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
82
+
83
+
84
+
85
+
86
+ class GeoChatMetaForCausalLM(ABC):
87
+
88
+ @abstractmethod
89
+ def get_model(self):
90
+ pass
91
+
92
+ def get_vision_tower(self):
93
+ return self.get_model().get_vision_tower()
94
+
95
+ def encode_images(self, images):
96
+ image_features = self.get_model().get_vision_tower()(images)
97
+ image_features = self.get_model().mm_projector(image_features)
98
+ return image_features
99
+
100
+ def prepare_inputs_labels_for_multimodal(
101
+ self, input_ids, attention_mask, past_key_values, labels, images
102
+ ):
103
+ vision_tower = self.get_vision_tower()
104
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
105
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
106
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
107
+ return input_ids, attention_mask, past_key_values, None, labels
108
+
109
+ if type(images) is list or images.ndim == 5:
110
+ concat_images = torch.cat([image for image in images], dim=0)
111
+ image_features = self.encode_images(concat_images)
112
+ split_sizes = [image.shape[0] for image in images]
113
+ image_features = torch.split(image_features, split_sizes, dim=0)
114
+ image_features = [x.flatten(0, 1) for x in image_features]
115
+ else:
116
+ image_features = self.encode_images(images)
117
+
118
+ new_input_embeds = []
119
+ new_labels = [] if labels is not None else None
120
+ cur_image_idx = 0
121
+ for batch_idx, cur_input_ids in enumerate(input_ids):
122
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
123
+ # multimodal LLM, but the current sample is not multimodal
124
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
125
+ half_len = cur_input_ids.shape[0] // 2
126
+ cur_image_features = image_features[cur_image_idx]
127
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
128
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
129
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
130
+ new_input_embeds.append(cur_input_embeds)
131
+ if labels is not None:
132
+ new_labels.append(labels[batch_idx])
133
+ cur_image_idx += 1
134
+ continue
135
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
136
+ cur_new_input_embeds = []
137
+ if labels is not None:
138
+ cur_labels = labels[batch_idx]
139
+ cur_new_labels = []
140
+ assert cur_labels.shape == cur_input_ids.shape
141
+ while image_token_indices.numel() > 0:
142
+ cur_image_features = image_features[cur_image_idx]
143
+ image_token_start = image_token_indices[0]
144
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
145
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
146
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
147
+ cur_new_input_embeds.append(cur_image_features)
148
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
149
+ if labels is not None:
150
+ cur_new_labels.append(cur_labels[:image_token_start])
151
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
152
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
153
+ cur_labels = cur_labels[image_token_start+2:]
154
+ else:
155
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
156
+ cur_new_input_embeds.append(cur_image_features)
157
+ if labels is not None:
158
+ cur_new_labels.append(cur_labels[:image_token_start])
159
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
160
+ cur_labels = cur_labels[image_token_start+1:]
161
+ cur_image_idx += 1
162
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
163
+ cur_input_ids = cur_input_ids[image_token_start+2:]
164
+ else:
165
+ cur_input_ids = cur_input_ids[image_token_start+1:]
166
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
167
+ if cur_input_ids.numel() > 0:
168
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
169
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
170
+ else:
171
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
172
+ if labels is not None:
173
+ cur_new_labels.append(cur_labels)
174
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
175
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
176
+ new_input_embeds.append(cur_new_input_embeds)
177
+ if labels is not None:
178
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
179
+ new_labels.append(cur_new_labels)
180
+
181
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
182
+ max_len = max(x.shape[0] for x in new_input_embeds)
183
+
184
+ new_input_embeds_align = []
185
+ for cur_new_embed in new_input_embeds:
186
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
187
+ new_input_embeds_align.append(cur_new_embed)
188
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
189
+
190
+ if labels is not None:
191
+ new_labels_align = []
192
+ _new_labels = new_labels
193
+ for cur_new_label in new_labels:
194
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
195
+ new_labels_align.append(cur_new_label)
196
+ new_labels = torch.stack(new_labels_align, dim=0)
197
+
198
+ if attention_mask is not None:
199
+ new_attention_mask = []
200
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
201
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
202
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
203
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
204
+ new_attention_mask.append(cur_new_attention_mask)
205
+ attention_mask = torch.stack(new_attention_mask, dim=0)
206
+ assert attention_mask.shape == new_labels.shape
207
+ else:
208
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
209
+ if labels is not None:
210
+ new_labels = torch.stack(new_labels, dim=0)
211
+
212
+ if attention_mask is not None:
213
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
214
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
215
+ assert attention_mask.shape == new_input_embeds.shape[:2]
216
+
217
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
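+
+ # input_ids is returned as None here because the image features have already been
+ # spliced into the token embeddings; the language model is then called with
+ # inputs_embeds instead of input_ids.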
218
+
219
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
220
+ if model_args.mm_use_im_patch_token:
221
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
222
+ self.resize_token_embeddings(len(tokenizer))
223
+
224
+ if model_args.mm_use_im_start_end:
225
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
226
+ self.resize_token_embeddings(len(tokenizer))
227
+
228
+ if num_new_tokens > 0:
229
+ input_embeddings = self.get_input_embeddings().weight.data
230
+ output_embeddings = self.get_output_embeddings().weight.data
231
+
232
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
233
+ dim=0, keepdim=True)
234
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
235
+ dim=0, keepdim=True)
236
+
237
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
238
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
239
+
240
+ if model_args.tune_mm_mlp_adapter:
241
+ for p in self.get_input_embeddings().parameters():
242
+ p.requires_grad = True
243
+ for p in self.get_output_embeddings().parameters():
244
+ p.requires_grad = False
245
+
246
+ if model_args.pretrain_mm_mlp_adapter:
247
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
248
+ # print(mm_projector_weights)  # debug print of the full adapter state dict disabled
249
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
250
+ assert num_new_tokens == 2
251
+ if input_embeddings.shape == embed_tokens_weight.shape:
252
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
253
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
254
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
255
+ else:
256
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
257
+ elif model_args.mm_use_im_patch_token:
258
+ if model_args.tune_mm_mlp_adapter:
259
+ for p in self.get_input_embeddings().parameters():
260
+ p.requires_grad = False
261
+ for p in self.get_output_embeddings().parameters():
262
+ p.requires_grad = False
Geo/GeochatP-main/geochat/model/language_model/geochat_llama.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+ from ..geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM
28
+
29
+
30
+ class GeoChatConfig(LlamaConfig):
31
+ model_type = "geochat"
32
+
33
+
34
+ class GeoChatLlamaModel(GeoChatMetaModel, LlamaModel):
35
+ config_class = GeoChatConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(GeoChatLlamaModel, self).__init__(config)
39
+
40
+
41
+ class GeoChatLlamaForCausalLM(LlamaForCausalLM, GeoChatMetaForCausalLM):
42
+ config_class = GeoChatConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = GeoChatLlamaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ labels: Optional[torch.LongTensor] = None,
63
+ use_cache: Optional[bool] = None,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ images: Optional[torch.FloatTensor] = None,
67
+ return_dict: Optional[bool] = None,
68
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+ output_hidden_states = (
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72
+ )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
76
+
77
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
78
+ outputs = self.model(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict
87
+ )
88
+
89
+ hidden_states = outputs[0]
90
+ logits = self.lm_head(hidden_states)
91
+
92
+ loss = None
93
+ if labels is not None:
94
+ # Shift so that tokens < n predict n
95
+ shift_logits = logits[..., :-1, :].contiguous()
96
+ shift_labels = labels[..., 1:].contiguous()
97
+ # Flatten the tokens
98
+ loss_fct = CrossEntropyLoss()
99
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
100
+ shift_labels = shift_labels.view(-1)
101
+ # Enable model/pipeline parallelism
102
+ shift_labels = shift_labels.to(shift_logits.device)
103
+ loss = loss_fct(shift_logits, shift_labels)
104
+
105
+ if not return_dict:
106
+ output = (logits,) + outputs[1:]
107
+ return (loss,) + output if loss is not None else output
108
+
109
+ return CausalLMOutputWithPast(
110
+ loss=loss,
111
+ logits=logits,
112
+ past_key_values=outputs.past_key_values,
113
+ hidden_states=outputs.hidden_states,
114
+ attentions=outputs.attentions,
115
+ )
116
+
117
+ def prepare_inputs_for_generation(
118
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119
+ ):
120
+ if past_key_values:
121
+ input_ids = input_ids[:, -1:]
122
+
123
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
124
+ if inputs_embeds is not None and past_key_values is None:
125
+ model_inputs = {"inputs_embeds": inputs_embeds}
126
+ else:
127
+ model_inputs = {"input_ids": input_ids}
128
+
129
+ model_inputs.update(
130
+ {
131
+ "past_key_values": past_key_values,
132
+ "use_cache": kwargs.get("use_cache"),
133
+ "attention_mask": attention_mask,
134
+ "images": kwargs.get("images", None),
135
+ }
136
+ )
137
+ return model_inputs
138
+
139
+ AutoConfig.register("geochat", GeoChatConfig)
140
+ AutoModelForCausalLM.register(GeoChatConfig, GeoChatLlamaForCausalLM)
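+ # Registering the config/model pair lets AutoConfig / AutoModelForCausalLM
+ # resolve checkpoints whose config declares model_type "geochat".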
Geo/GeochatP-main/geochat/model/language_model/geochat_mpt.py ADDED
@@ -0,0 +1,113 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple
+import warnings
+
+import torch
+import torch.nn.functional as F
+import math
+
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
+from geochat.model.geochat_arch import GeoChatMetaModel, GeoChatMetaForCausalLM
+
+
+class GeoChatMPTConfig(MPTConfig):
+    model_type = "geochat_mpt"
+
+
+class GeoChatMPTModel(GeoChatMetaModel, MPTModel):
+    config_class = GeoChatMPTConfig
+
+    def __init__(self, config: MPTConfig):
+        config.hidden_size = config.d_model
+        super(GeoChatMPTModel, self).__init__(config)
+
+    def embed_tokens(self, x):
+        return self.wte(x)
+
+
+class GeoChatMPTForCausalLM(MPTForCausalLM, GeoChatMetaForCausalLM):
+    config_class = GeoChatMPTConfig
+    supports_gradient_checkpointing = True
+
+    def __init__(self, config):
+        super(MPTForCausalLM, self).__init__(config)
+
+        if not config.tie_word_embeddings:
+            raise ValueError('MPTForCausalLM only supports tied word embeddings')
+        self.transformer = GeoChatMPTModel(config)
+        self.logit_scale = None
+        if config.logit_scale is not None:
+            logit_scale = config.logit_scale
+            if isinstance(logit_scale, str):
+                if logit_scale == 'inv_sqrt_d_model':
+                    logit_scale = 1 / math.sqrt(config.d_model)
+                else:
+                    raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+            self.logit_scale = logit_scale
+
+    def get_model(self):
+        return self.transformer
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, GeoChatMPTModel):
+            module.gradient_checkpointing = value
+
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+        outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
+        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
+        if self.logit_scale is not None:
+            if self.logit_scale == 0:
+                warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
+            logits *= self.logit_scale
+        loss = None
+        if labels is not None:
+            labels = torch.roll(labels, shifts=-1)
+            labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        if inputs_embeds is not None:
+            raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
+        attention_mask = kwargs['attention_mask'].bool()
+        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
+            raise NotImplementedError('MPT does not support generation with right padding.')
+        if self.transformer.attn_uses_sequence_id and self.training:
+            sequence_id = torch.zeros_like(input_ids[:1])
+        else:
+            sequence_id = None
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+        if self.transformer.prefix_lm:
+            prefix_mask = torch.ones_like(attention_mask)
+            if kwargs.get('use_cache') == False:
+                raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
+        else:
+            prefix_mask = None
+        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
+
+
+AutoConfig.register("geochat_mpt", GeoChatMPTConfig)
+AutoModelForCausalLM.register(GeoChatMPTConfig, GeoChatMPTForCausalLM)
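One detail worth noting in `GeoChatMPTForCausalLM.forward` is the loss computation: instead of the "drop the last logit, drop the first label" pattern used in the Llama variant, it rolls the labels left by one position and masks the final slot with `-100`. The two formulations are equivalent, as the toy check below illustrates (random tensors, not GeoChat outputs; `dims=1` is spelled out here, while the source omits it, which is harmless because the wrapped-around last column is masked anyway).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Llama-style: compare logits at position t with the label at position t+1.
loss_shift = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab),
    labels[:, 1:].reshape(-1),
)

# MPT-style: roll labels left by one and ignore the wrapped-around position.
rolled = torch.roll(labels, shifts=-1, dims=1)
rolled[:, -1] = -100  # default ignore_index for cross_entropy
loss_roll = F.cross_entropy(logits.reshape(-1, vocab), rolled.reshape(-1))

assert torch.allclose(loss_shift, loss_roll)
```

Both versions average the same per-token losses, since `cross_entropy` excludes `-100` targets from both the sum and the denominator.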
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-311.pyc ADDED
Binary file (3.17 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-38.pyc ADDED
Binary file (2.29 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/adapt_tokenizer.cpython-39.pyc ADDED
Binary file (2.27 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-311.pyc ADDED
Binary file (23.8 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-38.pyc ADDED
Binary file (11.9 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/attention.cpython-39.pyc ADDED
Binary file (12 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (2.75 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-311.pyc ADDED
Binary file (5.09 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (2.71 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (2.71 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-310.pyc ADDED
Binary file (8.66 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-311.pyc ADDED
Binary file (11.3 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-38.pyc ADDED
Binary file (8.61 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/configuration_mpt.cpython-39.pyc ADDED
Binary file (8.61 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-310.pyc ADDED
Binary file (763 Bytes). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-311.pyc ADDED
Binary file (1.14 kB). View file
 
Geo/GeochatP-main/geochat/model/language_model/mpt/__pycache__/custom_embedding.cpython-38.pyc ADDED
Binary file (757 Bytes). View file