tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ images = process_images(images, image_processor, model.config)
+
+ if type(images) is list:
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+ else:
+ images = images.to(self.model.device, dtype=torch.float16)
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+ else:
+ images = None
+ image_args = {"images": images}
+ else:
+ images = None
+ image_args = {}
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ do_sample = temperature > 0.001
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
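+ # Clamp the generation budget so prompt tokens, image patch tokens, and new tokens all fit within the model's context window.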
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
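+ # Run generation in a background thread; TextIteratorStreamer yields partial outputs as they are produced.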
+ thread = Thread(target=model.generate, kwargs=dict(
+ inputs=input_ids,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ streamer=streamer,
+ stopping_criteria=[stopping_criteria],
+ use_cache=True,
+ **image_args
+ ))
+ thread.start()
+
+ generated_text = ori_prompt
+ for new_text in streamer:
+ generated_text += new_text
+ if generated_text.endswith(stop_str):
+ generated_text = generated_text[:-len(stop_str)]
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ def generate_stream_gate(self, params):
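+ # Stream generation output, converting ValueError / CUDA errors into a JSON error payload instead of crashing the worker.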
+ try:
+ for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.CudaError as e:
+ print("Caught torch.cuda.CudaError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
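+ # At most --limit-model-concurrency generations run at once; the semaphore is released by a background task after streaming completes.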
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.multi_modal:
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_base,
+ args.model_name,
+ args.load_8bit,
+ args.load_4bit,
+ args.device)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/Geo/GeochatP-main/geochat/serve/register_worker.py b/Geo/GeochatP-main/geochat/serve/register_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m geochat.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/Geo/GeochatP-main/geochat/serve/test_message.py b/Geo/GeochatP-main/geochat/serve/test_message.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b090faed0e630b03b2294545050f1f4f5032cad
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/serve/test_message.py
@@ -0,0 +1,62 @@
+import argparse
+import json
+
+import requests
+
+from geochat.conversation import default_conversation
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(controller_addr + "/get_worker_address",
+ json={"model": args.model_name})
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], args.message)
+ prompt = conv.get_prompt()
+
+ headers = {"User-Agent": "LLaVA Client"}
+ pload = {
+ "model": args.model_name,
+ "prompt": prompt,
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.7,
+ "stop": conv.sep,
+ }
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
+ json=pload, stream=True)
+
+ print(prompt.replace(conv.sep, "\n"), end="")
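+ # The worker streams null-delimited JSON chunks; each chunk carries the full text generated so far.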
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["text"].split(conv.sep)[-1]
+ print(output, end="\r")
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument("--message", type=str, default=
+ "Tell me a story with more than 1000 words.")
+ args = parser.parse_args()
+
+ main()
diff --git a/Geo/GeochatP-main/geochat/train/geochat_trainer.py b/Geo/GeochatP-main/geochat/train/geochat_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..93fe8043a40c133410b2ae8516fd25c3c1deca21
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/train/geochat_trainer.py
@@ -0,0 +1,175 @@
+import os
+import torch
+
+from torch.utils.data import Sampler
+
+from transformers import Trainer
+from transformers.trainer import (
+ has_length,
+)
+from typing import List, Optional
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
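+ # Positive lengths mark multimodal samples; negative lengths mark language-only samples (see LazySupervisedDataset.modality_lengths).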
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ assert len(mm_indices) > 0, "Should have at least one multimodal sample."
+ assert len(lang_indices) > 0, "Should have at least one language sample."
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) >= megabatch_size:
+ megabatches = [additional_batch[:megabatch_size]] + megabatches
+ additional_batch = additional_batch[megabatch_size:]
+
+ if len(additional_batch) > 0:
+ megabatches.append(additional_batch)
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
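+ # Shuffle indices, sort each (world_size * batch_size) megabatch by length, then split it into per-rank chunks with similar total length.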
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class GeoChatTrainer(Trainer):
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ world_size=self.args.world_size,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ['mm_projector', 'vision_resampler']
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
+ else:
+ super(GeoChatTrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ pass
+ else:
+ super(GeoChatTrainer, self)._save(output_dir, state_dict)
diff --git a/Geo/GeochatP-main/geochat/train/llama_flash_attn_monkey_patch.py b/Geo/GeochatP-main/geochat/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..31db2eff8d1c4b3ae645583dfc5e156e818b6f1c
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,115 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
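+# Older flash-attn releases expose flash_attn_unpadded_qkvpacked_func; newer releases renamed it to flash_attn_varlen_qkvpacked_func.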
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/Geo/GeochatP-main/geochat/train/train.py b/Geo/GeochatP-main/geochat/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2de678368176479561c0528d8e2372a4877f79d
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/train/train.py
@@ -0,0 +1,957 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+import torch
+
+import transformers
+
+from geochat.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from torch.utils.data import Dataset
+from geochat.train.geochat_trainer import GeoChatTrainer
+
+from geochat import conversation as conversation_lib
+from geochat.model import *
+from geochat.mm_utils import tokenizer_image_token
+
+from PIL import Image
+
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=None)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default='linear')
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: str = 'square'
+ image_grid_pinpoints: Optional[str] = field(default=None)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ group_by_modality_length: bool = field(default=False)
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias.items():
+ if k in lora_bias_names:
+ to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
+ # Only save Adapter
+ keys_to_match = ['mm_projector']
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
+ return
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
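+ # Initialize the embeddings of newly added tokens to the mean of the existing embeddings rather than random values.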
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments
+) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_llama_2(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_mpt(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "mpt":
+ return preprocess_mpt(sources, tokenizer)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments):
+ super(LazySupervisedDataset, self).__init__()
+ list_data_dict = json.load(open(data_path, "r"))
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if 'image' in sample else 0
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+ return length_list
+
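+ # Text-only samples get negative lengths so the modality-grouped sampler can tell them apart from image samples.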
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+ cur_len = cur_len if 'image' in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ if 'image' in sources[0]:
+ image_file = self.list_data_dict[i]['image']
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ image = Image.open((os.path.join(image_folder, image_file)).strip()).convert('RGB')
+
+ if self.data_args.image_aspect_ratio == 'pad':
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+ # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ image = processor.preprocess(image, do_resize=True, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values'][0]
+ else:
+ # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ image = processor.preprocess(image, do_resize=True, crop_size={'height': 504, 'width': 504}, size={'shortest_edge': 504}, return_tensors='pt')['pixel_values'][0]
+
+ sources = preprocess_multimodal(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.data_args)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=('image' in self.list_data_dict[i]))
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if 'image' in self.list_data_dict[i]:
+ data_dict['image'] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ labels = labels[:, :self.tokenizer.model_max_length]
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
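+ # Stack images into a single tensor when all shapes match; otherwise pass them through as a list.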
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ data_args=data_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
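+ # For 4-/8-bit training, load the base model through a bitsandbytes quantization config.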
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ if model_args.vision_tower is not None:
+ if 'mpt' in model_args.model_name_or_path:
+ config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+ config.attn_config['attn_impl'] = training_args.mpt_attn_impl
+ model = GeoChatMPTForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ else:
+ model = GeoChatLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
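+ # Optionally attach LoRA adapters to the language model's linear layers; find_all_linear_names excludes the multimodal projector and vision tower.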
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ if 'mpt' in model_args.model_name_or_path:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right"
+ )
+ else:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ elif model_args.version == "v0.5":
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_vision_modules(
+ model_args=model_args,
+ fsdp=training_args.fsdp
+ )
+
+ vision_tower = model.get_vision_tower()
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = vision_tower.image_processor
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ if model_args.tune_mm_mlp_adapter:
+ model.requires_grad_(False)
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+ trainer = GeoChatTrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/Geo/GeochatP-main/geochat/train/train_mem.py b/Geo/GeochatP-main/geochat/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f425f9457168cc83447ab117b1e8ac99557009d
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/train/train_mem.py
@@ -0,0 +1,13 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from geochat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+replace_llama_attn_with_flash_attn()
+
+from geochat.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/Geo/GeochatP-main/geochat/utils.py b/Geo/GeochatP-main/geochat/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c75ddbea7dbc0c2f1b20e095dcd47ca9362e8456
--- /dev/null
+++ b/Geo/GeochatP-main/geochat/utils.py
@@ -0,0 +1,126 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from geochat.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates the OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ try:
+ ret = requests.post(url, headers=headers, json={"input": text}, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
diff --git a/Geo/GeochatP-main/geochat_demo.py b/Geo/GeochatP-main/geochat_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..939ef3444933f8b4e934fcdd5f4e9cc128acccf7
--- /dev/null
+++ b/Geo/GeochatP-main/geochat_demo.py
@@ -0,0 +1,706 @@
+import argparse
+import os
+import random
+from collections import defaultdict
+
+import cv2
+import re
+import math
+import numpy as np
+from PIL import Image
+import torch
+import html
+import gradio as gr
+
+import torchvision.transforms as T
+import torch.backends.cudnn as cudnn
+
+from geochat.conversation import conv_templates, Chat
+from geochat.model.builder import load_pretrained_model
+from geochat.mm_utils import get_model_name_from_path
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ # parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--gpu-id", type=str,default=0)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--max-new-tokens", type=int, default=300)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ # args = parser.parse_args()
+ args = parser.parse_args()
+ return args
+
+
+random.seed(42)
+np.random.seed(42)
+torch.manual_seed(42)
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+print('Initializing Chat')
+args = parse_args()
+# cfg = Config(args)
+
+model_name = get_model_name_from_path(args.model_path)
+tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+device = 'cuda:{}'.format(args.gpu_id)
+
+# model_config = cfg.model_cfg
+# model_config.device_8bit = args.gpu_id
+# model_cls = registry.get_model_class(model_config.arch)
+# model = model_cls.from_config(model_config).to(device)
+bounding_box_size = 100
+
+# vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+# vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+model = model.eval()
+
+CONV_VISION = conv_templates['llava_v1'].copy()
+
+def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
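+ # Convert an axis-aligned box (x1, y1, x2, y2) plus a rotation angle `a` (in degrees)
+ # into the eight coordinates of the corresponding rotated four-corner polygon,
+ # rotated about the box centre.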
+ # Calculate center coordinates
+ x_ctr = (x1 + x2) / 2
+ y_ctr = (y1 + y2) / 2
+
+ # Calculate width and height
+ w = abs(x2 - x1)
+ h = abs(y2 - y1)
+
+ # Calculate the angle in radians
+ angle_rad = math.radians(a)
+
+ # Calculate coordinates of the four corners of the rotated bounding box
+ cos_a = math.cos(angle_rad)
+ sin_a = math.sin(angle_rad)
+
+ x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
+ y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr
+
+ x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
+ y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr
+
+ x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
+ y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr
+
+ x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
+ y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr
+
+ # Return the polygon coordinates
+ polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))
+
+ return polygon_coords
+
+def rotate_bbox(top_right, bottom_left, angle_degrees):
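+ # Return the four corner points of the box defined by `top_right` and `bottom_left`,
+ # rotated by `angle_degrees` about its centre via OpenCV's 2D rotation matrix;
+ # the result is a (4, 2) array usable with cv2.polylines.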
+ # Convert angle to radians
+ angle_radians = np.radians(angle_degrees)
+
+ # Calculate the center of the rectangle
+ center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)
+
+ # Calculate the width and height of the rectangle
+ width = top_right[0] - bottom_left[0]
+ height = top_right[1] - bottom_left[1]
+
+ # Create a rotation matrix
+ rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)
+
+ # Create an array of the rectangle corners
+ rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
+ [top_right[0], bottom_left[1]],
+ [top_right[0], top_right[1]],
+ [bottom_left[0], top_right[1]]], dtype=np.float32)
+
+ # Rotate the rectangle points
+ rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]
+
+ return rotated_rectangle
+def extract_substrings(string):
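+ # Extract one substring per grounded phrase and its box list from the model output
+ # (everything up to each closing '}' that ends a group of boxes).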
+ # first check if there is no-finished bracket
+ index = string.rfind('}')
+ if index != -1:
+ string = string[:index + 1]
+
+ pattern = r'<p>(.*?)\}(?!<)'
+ matches = re.findall(pattern, string)
+ substrings = [match for match in matches]
+
+ return substrings
+
+
+def is_overlapping(rect1, rect2):
+ x1, y1, x2, y2 = rect1
+ x3, y3, x4, y4 = rect2
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+def computeIoU(bbox1, bbox2):
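+ # Intersection-over-union of two axis-aligned boxes given as (x1, y1, x2, y2),
+ # using an inclusive (+1) pixel convention for widths and heights.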
+ x1, y1, x2, y2 = bbox1
+ x3, y3, x4, y4 = bbox2
+ intersection_x1 = max(x1, x3)
+ intersection_y1 = max(y1, y3)
+ intersection_x2 = min(x2, x4)
+ intersection_y2 = min(y2, y4)
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+ union_area = bbox1_area + bbox2_area - intersection_area
+ iou = intersection_area / union_area
+ return iou
+
+
+def save_tmp_img(visual_img):
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+ file_path = "/tmp/gradio" + file_name
+ visual_img.save(file_path)
+ return file_path
+
+
+def mask2bbox(mask):
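+ # Convert a user-drawn sketch mask into the textual box format expected by the model:
+ # '{<xmin><ymin><xmax><ymax>}' on a 100x100 normalized grid (empty string if no mask).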
+ if mask is None:
+ return ''
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
+ mask = np.array(mask)[:, :, 0]
+
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if rows.sum():
+ # Get the top, bottom, left, and right boundaries
+ rmin, rmax = np.where(rows)[0][[0, -1]]
+ cmin, cmax = np.where(cols)[0][[0, -1]]
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+ else:
+ bbox = ''
+
+ return bbox
+
+
+def escape_markdown(text):
+ # List of Markdown special characters that need to be escaped
+ md_chars = ['<', '>']
+
+ # Escape each special character
+ for char in md_chars:
+ text = text.replace(char, '\\' + char)
+
+ return text
+
+
+def reverse_escape(text):
+ md_chars = ['\\<', '\\>']
+
+ for char in md_chars:
+ text = text.replace(char, char[1:])
+
+ return text
+
+
+colors = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (210, 210, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (114, 128, 250),
+ (0, 165, 255),
+ (0, 128, 0),
+ (144, 238, 144),
+ (238, 238, 175),
+ (255, 191, 0),
+ (0, 128, 0),
+ (226, 43, 138),
+ (255, 0, 255),
+ (0, 215, 255),
+]
+
+color_map = {
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+ color_id, color in enumerate(colors)
+}
+
+used_colors = colors
+
+
+def visualize_all_bbox_together(image, generation):
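+ # Parse grounded phrases and their (optionally rotated) boxes out of the model output
+ # `generation`, draw them on `image` (box coordinates are normalized to
+ # `bounding_box_size`, i.e. 0-100), and return the annotated image plus a
+ # color-coded version of the text.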
+ if image is None:
+ return None, ''
+
+ generation = html.unescape(generation)
+
+ image_width, image_height = image.size
+ image = image.resize([500, int(500 / image_width * image_height)])
+ image_width, image_height = image.size
+
+ string_list = extract_substrings(generation)
+ if string_list: # it is grounding or detection
+ mode = 'all'
+ entities = defaultdict(list)
+ i = 0
+ j = 0
+ for string in string_list:
+ try:
+ obj, string = string.split('</p>')
+ except ValueError:
+ print('wrong string: ', string)
+ continue
+ if "}{" in string:
+ string = string.replace("}{", "}<delim>{")
+ bbox_list = string.split('<delim>')
+ flag = False
+ for bbox_string in bbox_list:
+ integers = re.findall(r'-?\d+', bbox_string)
+ if len(integers) == 4:
+     angle = 0
+ else:
+     angle = integers[4]
+     integers = integers[:-1]  # strip the trailing angle only when one is present
+ if len(integers) == 4:
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+
+ entities[obj].append([left, bottom, right, top,angle])
+
+ j += 1
+ flag = True
+ if flag:
+ i += 1
+ else:
+ integers = re.findall(r'-?\d+', generation)
+ if len(integers) == 4:
+     angle = 0
+ else:
+     angle = integers[4]
+     integers = integers[:-1]  # strip the trailing angle only when one is present
+ if len(integers) == 4: # it is refer
+ mode = 'single'
+
+ entities = list()
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+ entities.append([left, bottom, right, top,angle])
+ else:
+ # don't detect any valid bbox to visualize
+ return None, ''
+
+ if len(entities) == 0:
+ return None, ''
+
+ if isinstance(image, Image.Image):
+ image_h = image.height
+ image_w = image.width
+ image = np.array(image)
+
+ elif isinstance(image, str):
+ if os.path.exists(image):
+ pil_img = Image.open(image).convert("RGB")
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ image_h = pil_img.height
+ image_w = pil_img.width
+ else:
+ raise ValueError(f"invaild image path, {image}")
+ elif isinstance(image, torch.Tensor):
+
+ image_tensor = image.cpu()
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+ pil_img = T.ToPILImage()(image_tensor)
+ image_h = pil_img.height
+ image_w = pil_img.width
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ else:
+ raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+ indices = list(range(len(entities)))
+
+ new_image = image.copy()
+
+ previous_bboxes = []
+ # size of text
+ text_size = 0.4
+ # thickness of text
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
+ box_line = 2
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+ base_height = int(text_height * 0.675)
+ text_offset_original = text_height - base_height
+ text_spaces = 2
+
+ # num_bboxes = sum(len(x[-1]) for x in entities)
+ used_colors = colors # random.sample(colors, k=num_bboxes)
+
+ color_id = -1
+ for entity_idx, entity_name in enumerate(entities):
+ if mode == 'single' or mode == 'identify':
+ bboxes = entity_name
+ bboxes = [bboxes]
+ else:
+ bboxes = entities[entity_name]
+ color_id += 1
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm,angle) in enumerate(bboxes):
+ skip_flag = False
+ orig_x1, orig_y1, orig_x2, orig_y2,angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)
+
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
+ top_right = (orig_x1, orig_y1)
+ bottom_left = (orig_x2, orig_y2)
+ rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
+ new_image = cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True, thickness=2, color=color)
+
+ # new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+ if mode == 'all':
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+ x1 = orig_x1 - l_o
+ y1 = orig_y1 - l_o
+
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+ x1 = orig_x1 + r_o
+
+ # add text background
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+ text_line)
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+ for prev_bbox in previous_bboxes:
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+ prev_bbox['phrase'] == entity_name:
+ skip_flag = True
+ break
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+ if text_bg_y2 >= image_h:
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+ text_bg_y2 = image_h
+ y1 = image_h
+ break
+ if not skip_flag:
+ alpha = 0.5
+ for i in range(text_bg_y1, text_bg_y2):
+ for j in range(text_bg_x1, text_bg_x2):
+ if i < image_h and j < image_w:
+ if j < text_bg_x1 + 1.35 * c_width:
+ # original color
+ bg_color = color
+ else:
+ # white
+ bg_color = [255, 255, 255]
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+ np.uint8)
+
+ cv2.putText(
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+ )
+
+ previous_bboxes.append(
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+ if mode == 'all':
+ def color_iterator(colors):
+ while True:
+ for color in colors:
+ yield color
+
+ color_gen = color_iterator(colors)
+
+ # Add colors to phrases and remove the <p></p> tags
+ def colored_phrases(match):
+ phrase = match.group(1)
+ color = next(color_gen)
+ return f'<span style="color:rgb{color}">{phrase}</span>'
+
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+ generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
+ else:
+ generation_colored = ''
+
+ pil_image = Image.fromarray(new_image)
+ return pil_image, generation_colored
+
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
+ interactive=True), chat_state, img_list
+
+
+def image_upload_trigger(upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list:
+ replace_flag = 1
+ return upload_flag, replace_flag
+
+
+def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list or replace_flag == 1:
+ replace_flag = 1
+
+ return upload_flag, replace_flag
+
+
+def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+ if len(user_message) == 0:
+ text_box_show = 'Input should not be empty!'
+ else:
+ text_box_show = ''
+
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+ else:
+ mask = None
+
+ if '[identify]' in user_message:
+ # check if user provide bbox in the text input
+ integers = re.findall(r'-?\d+', user_message)
+ if len(integers) != 4: # no bbox in text
+ bbox = mask2bbox(mask)
+ user_message = user_message + bbox
+
+ if chat_state is None:
+ chat_state = CONV_VISION.copy()
+
+ if upload_flag:
+ if replace_flag:
+ chat_state = CONV_VISION.copy() # new image, reset everything
+ replace_flag = 0
+ chatbot = []
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ upload_flag = 0
+
+ chat.ask(user_message, chat_state)
+
+ chatbot = chatbot + [[user_message, None]]
+
+ if '[identify]' in user_message:
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+ if visual_img is not None:
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[(file_path,), None]]
+
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
+# def gradio_answer(chatbot, chat_state, img_list, temperature):
+# llm_message = chat.answer(conv=chat_state,
+# img_list=img_list,
+# temperature=temperature,
+# max_new_tokens=500,
+# max_length=2000)[0]
+# chatbot[-1][1] = llm_message
+# return chatbot, chat_state
+
+
+def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
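+ # Stream the model's answer token by token into the last chatbot turn, escaping
+ # Markdown special characters ('<', '>') so raw box tags are rendered literally.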
+ if len(img_list) > 0:
+ if not isinstance(img_list[0], torch.Tensor):
+ chat.encode_img(img_list)
+ streamer = chat.stream_answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)
+ # chatbot[-1][1] = output
+ # chat_state.messages[-1][1] = ''
+
+ output = ''
+ for new_output in streamer:
+ # print(new_output)
+ output=output+new_output
+ print(output)
+ # if "{" in output:
+ # chatbot[-1][1]="Grounding and referring expression is still under work."
+ # else:
+ output = escape_markdown(output)
+ # output += escapped
+ chatbot[-1][1] = output
+ yield chatbot, chat_state
+ chat_state.messages[-1][1] = ''
+ return chatbot, chat_state
+
+
+def gradio_visualize(chatbot, gr_img):
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+
+ unescaped = reverse_escape(chatbot[-1][1])
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+ if visual_img is not None:
+ if len(generation_color):
+ chatbot[-1][1] = generation_color
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[None, (file_path,)]]
+
+ return chatbot
+
+
+def gradio_taskselect(idx):
+ prompt_list = [
+ '',
+ 'Classify the image in the following classes: ',
+ '[identify] what is this ',
+ ]
+ instruct_list = [
+ '**Hint:** Type in whatever you want',
+ '**Hint:** Type in the classes you want the model to classify in',
+ '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button at the top right of the image before drawing a new box',
+ ]
+ return prompt_list[idx], instruct_list[idx]
+
+
+
+
+chat = Chat(model, image_processor,tokenizer, device=device)
+
+
+title = """GeoChat Demo
"""
+description = 'Welcome to Our GeoChat Chatbot Demo!'
+article = """"""
+# article = """
"""
+
+introduction = '''
+1. Identify: Draw a bounding box on the uploaded image and CLICK **Send** to generate the grounded bounding box. (CLICK the "clear" button before drawing the next one.)
+2. No Tag: Input whatever you want and CLICK **Send** without any tagging.
+
+You can also simply chat in free form!
+'''
+
+
+text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
+ scale=12)
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ # gr.Markdown(description)
+ gr.Markdown(article)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=1.5,
+ value=0.6,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ clear = gr.Button("Restart")
+
+ gr.Markdown(introduction)
+
+ with gr.Column():
+ chat_state = gr.State(value=None)
+ img_list = gr.State(value=[])
+ chatbot = gr.Chatbot(label='GeoChat')
+
+ dataset = gr.Dataset(
+ components=[gr.Textbox(visible=False)],
+ samples=[['No Tag'], ['Scene Classification'],['Identify']],
+ type="index",
+ label='Task Shortcuts',
+ )
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
+ with gr.Row():
+ text_input.render()
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+ upload_flag = gr.State(value=0)
+ replace_flag = gr.State(value=0)
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+
+ with gr.Row():
+ with gr.Column():
+ gr.Examples(examples=[
+ ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?", upload_flag, replace_flag,
+ img_list],
+ ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+ with gr.Column():
+ gr.Examples(examples=[
+ ["demo_images/church_183.png", "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
+ upload_flag, replace_flag, img_list],
+ ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+
+ dataset.click(
+ gradio_taskselect,
+ inputs=[dataset],
+ outputs=[text_input, task_inst],
+ show_progress="hidden",
+ postprocess=False,
+ queue=False,
+ )
+
+ text_input.submit(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ send.click(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
+
+demo.launch(share=True, enable_queue=True, server_name='0.0.0.0')
diff --git a/Geo/GeochatP-main/images/IVAL_logo.png b/Geo/GeochatP-main/images/IVAL_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..5cd65234cc600ddeb3cdf255477e0044ef8aef5d
Binary files /dev/null and b/Geo/GeochatP-main/images/IVAL_logo.png differ
diff --git a/Geo/GeochatP-main/images/MBZUAI_logo.png b/Geo/GeochatP-main/images/MBZUAI_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..1aededc586c8cd0a21ad0ceaff0465861cc4205b
Binary files /dev/null and b/Geo/GeochatP-main/images/MBZUAI_logo.png differ
diff --git a/Geo/GeochatP-main/images/Oryx_logo.png b/Geo/GeochatP-main/images/Oryx_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..745cbdf20a4e894a19e072c932f58e093ac89e15
Binary files /dev/null and b/Geo/GeochatP-main/images/Oryx_logo.png differ
diff --git a/Geo/GeochatP-main/images/architecture.png b/Geo/GeochatP-main/images/architecture.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d00c49da47726d080a67e64ee508140543a55a1
--- /dev/null
+++ b/Geo/GeochatP-main/images/architecture.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9824f86304c64181844b789a5686b881daebdff2b662e7543087457600d9851d
+size 2185039
diff --git a/Geo/GeochatP-main/images/dataset.png b/Geo/GeochatP-main/images/dataset.png
new file mode 100644
index 0000000000000000000000000000000000000000..896c89fd022776b14d3149b11f69dca1cb6cc463
--- /dev/null
+++ b/Geo/GeochatP-main/images/dataset.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66049cf744404dd50e09173a373818dc9012cfce70ddc39dd9e4787c51606768
+size 2178146
diff --git a/Geo/GeochatP-main/images/examples.png b/Geo/GeochatP-main/images/examples.png
new file mode 100644
index 0000000000000000000000000000000000000000..08bd092de500456882a13751e2eca11d1e992735
--- /dev/null
+++ b/Geo/GeochatP-main/images/examples.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f709711a4a87633b7af9e86597f48f679db73c04251c7e14027901c2cb98ecce
+size 1710487
diff --git a/Geo/GeochatP-main/images/grounded.jpg b/Geo/GeochatP-main/images/grounded.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7bcded35713beca2df24e5c76ea066e45bfad54f
--- /dev/null
+++ b/Geo/GeochatP-main/images/grounded.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54da3de6b3bed90ecd6966075c265a6e77ef49cf5f362dd4f8d5a7b56dbad874
+size 1781213
diff --git a/Geo/GeochatP-main/images/iden.jpg b/Geo/GeochatP-main/images/iden.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f601574cd30fcda3e7fd8953d372125bcd341b83
--- /dev/null
+++ b/Geo/GeochatP-main/images/iden.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bab022c94989c795ef4defcd1c103eace3b972b72e8c79a21e3f006f2a70338d
+size 1074329
diff --git a/Geo/GeochatP-main/images/logo_geochat.png b/Geo/GeochatP-main/images/logo_geochat.png
new file mode 100644
index 0000000000000000000000000000000000000000..77c95ae0b9b1921fe4f770f108c341fa029487ce
--- /dev/null
+++ b/Geo/GeochatP-main/images/logo_geochat.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ee933ef806ea2b5afc700ce144f766783d8bb9442834801a194993712f0025b
+size 300174
diff --git a/Geo/GeochatP-main/images/overview2.png b/Geo/GeochatP-main/images/overview2.png
new file mode 100644
index 0000000000000000000000000000000000000000..242fa7421986aee97551b28042efd15a07906ed2
--- /dev/null
+++ b/Geo/GeochatP-main/images/overview2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0919e70217d3ba2976a0899531444131666dde89180c58824b15c02111f653bb
+size 1761151
diff --git a/Geo/GeochatP-main/images/ref1.jpg b/Geo/GeochatP-main/images/ref1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b591fd4a7a477182fb09f40e425155a9d31458b9
--- /dev/null
+++ b/Geo/GeochatP-main/images/ref1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd64d41ac2f9ed413e7341a86b109cec6b289d4b3b2fe93c484eb96d6b0d9db7
+size 1404643
diff --git a/Geo/GeochatP-main/images/ref_2.jpg b/Geo/GeochatP-main/images/ref_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b609d784a2eb9fb6356c3adb4638d8dfb99e04cd
--- /dev/null
+++ b/Geo/GeochatP-main/images/ref_2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0397c3b2b152c940c42f0088a0edd6ce52a1537be1473e07ffb0cbc672992f0b
+size 1750538
diff --git a/Geo/GeochatP-main/images/scene.jpg b/Geo/GeochatP-main/images/scene.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..91941c1794b2f124588567f8e9818a42cb0804ba
--- /dev/null
+++ b/Geo/GeochatP-main/images/scene.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb7340be3569d95ada355a8945b9af7f182cdc19d4f4e6e0c885bae6b7af3206
+size 1402791
diff --git a/Geo/GeochatP-main/images/teaser.png b/Geo/GeochatP-main/images/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..833cf9719e05b9e6e45336d2fd2781a27c936fbc
--- /dev/null
+++ b/Geo/GeochatP-main/images/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:deb585cd00a836d35ee927bb40220b4920696c63e678b24079951f08603116bf
+size 644332
diff --git a/Geo/GeochatP-main/images/vqa.jpg b/Geo/GeochatP-main/images/vqa.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..73dd07aa5b3e21cdf3fc228a8f20410453b83030
--- /dev/null
+++ b/Geo/GeochatP-main/images/vqa.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf2e8338cde2a57841463688c991787e78e9f3cd7a8e60f19f07864f61c6da8c
+size 735179
diff --git a/Geo/GeochatP-main/playground/data/prompts/conversation/000_caps.txt b/Geo/GeochatP-main/playground/data/prompts/conversation/000_caps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ac683b2a91e555b3045377a1a186a00d5cf29dac
--- /dev/null
+++ b/Geo/GeochatP-main/playground/data/prompts/conversation/000_caps.txt
@@ -0,0 +1 @@
+This is a view from above of harbor. A white ship anchored at harbor at the left. A white ship anchored at harbor at the left. A white ship anchored at harbor at the left. A white ship anchored at harbor at the left. A white ship anchored at harbor at the left. 3 white ships anchored at harbor at the center. A white ship anchored at harbor at the center. A white ship anchored at harbor at the center. 2 silver ships anchored at harbor at the left. A mostly gray ship anchored at harbor at the left. A mostly gray ship anchored at harbor at the left. A mostly gray ship anchored at harbor at the bottom right. A mostly gray ship anchored at harbor at the bottom right. 3 mostly gray ships anchored at harbor at the bottom right. 2 mostly gray ships anchored at harbor at the bottom right. A mostly gray ship anchored at harbor at the left. A mostly gray ship anchored at harbor at the bottom right. 2 mostly gray ships anchored at harbor at the bottom. 2 mostly gray ships anchored at harbor at the bottom. 2 mostly gray ships anchored at harbor at the bottom. A mostly gray ship anchored at harbor at the bottom. 2 mostly gray ships anchored at harbor at the center. A mostly gray ship anchored at harbor at the left. 3 harbor close to each other at bottom. 3 tennis-court close to each other at top. 5 gray small-vehicle at the bottom. 1 gray large-vehicle at the bottom left. 3 white small-vehicle at the bottom left. 8 mostly black small-vehicle at the bottom right. 2 gray small-vehicle at the center. 1 harbor at the left. 1 gray small-vehicle at the left. 1 black small-vehicle at the right. 6 black small-vehicle at the top. 1 swimming-pool at the top. 2 white large-vehicle at the top left. 7 gray small-vehicle at the top left
diff --git a/Geo/GeochatP-main/playground/data/prompts/conversation/000_conv.txt b/Geo/GeochatP-main/playground/data/prompts/conversation/000_conv.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f4952ec742f02a31930903b7b58835a87a871083
--- /dev/null
+++ b/Geo/GeochatP-main/playground/data/prompts/conversation/000_conv.txt
@@ -0,0 +1,53 @@
+Question: How many ships are anchored at the left of the harbor?
+Answer: There are five white ships anchored at the left of the harbor.
+
+Question: What is the color of the ships anchored at the left of the harbor?
+Answer: The ships anchored at the left of the harbor are white in color.
+
+Question: Can you describe the area at the center of the harbor in terms of anchored ships?
+Answer: In the center of the harbor, there are three ships which are white in color and two silver ships anchored.
+
+Question: How many mostly gray ships are anchored at the bottom right of the harbor?
+Answer: There are three mostly gray ships anchored at the bottom right of the harbor.
+
+Question: What is the size of the silver vehicles at the top left?
+Answer: The silver vehicles at the top left are normal sized.
+
+Question: Are there any tennis courts close to each other in the image?
+Answer: Yes, there are three tennis courts close to each other at the top of the image.
+
+Question: How many cars are at the bottom of the image?
+Answer: There are five cars at the bottom of the image, which are grey in color.
+
+Question: Can you identify the location of a gray truck?
+Answer: Yes, there is one gray truck at the bottom left of the image.
+
+Question: How many mostly black cars are at the bottom right of the image?
+Answer: There are eight mostly black cars at the bottom right of the image.
+
+Question: Is there a swimming pool in the image, and if so, where is it located?
+Answer: Yes, there is one swimming pool at the top of the image.
+
+Question: Can you describe the presence of white large vehicles in the image?
+Answer: There are two white large vehicles at the top left of the image.
+
+Question: What can you infer about the grouping of three harbors close to each other at the bottom of the image?
+Answer: The presence of three close-together harbors at the bottom suggests a cluster of maritime activity, potentially indicating a busy port area or waterfront location where ships dock.
+
+Question: What is the color of the mostly gray ships anchored at the bottom right of the harbor?
+Answer: The mostly gray ships anchored at the bottom right of the harbor are gray.
+
+Question: Can you identify any mostly black cars in the image, and if so, how many are there?
+Answer: Yes, there are eight mostly black cars in the image.
+
+Question: What is the relative size of the gray truck at the bottom left compared to the other vehicles?
+Answer: The gray truck at the bottom left is larger in size compared to the other vehicles.
+
+Question: How many white large vehicles are there at the top left, and what can you tell about their relative size?
+Answer: There are two white large vehicles at the top left, and they are larger in size compared to the smaller vehicles in the image.
+
+Question: Are the white ships anchored at the left of the harbor of the same size, or is there a size difference among them?
+Answer: The white ships anchored at the left of the harbor appear to be of the same size.
+
+Question: Given the presence of various ships, vehicles, and facilities like tennis courts and a swimming pool, what type of scene or environment does this image likely depict? What potential activities or interactions can you infer from the arrangement of these elements?
+Answer: This image likely represents an aerial view of a coastal area or harbor, characterized by a mix of maritime and recreational facilities. The numerous anchored ships suggest a bustling port, while the presence of tennis courts, a swimming pool, and various vehicles hints at a multifaceted waterfront space. It's possible that this area serves both commercial and recreational purposes, with ships being loaded or unloaded, and people engaging in leisure activities nearby. The variety of elements in this scene points to a dynamic and versatile waterfront environment.
diff --git a/Geo/GeochatP-main/playground/data/prompts/conversation/001_caps.txt b/Geo/GeochatP-main/playground/data/prompts/conversation/001_caps.txt
new file mode 100644
index 0000000000000000000000000000000000000000..40203035a4f23d92d25b9b1a90c98ca4dc518649
--- /dev/null
+++ b/Geo/GeochatP-main/playground/data/prompts/conversation/001_caps.txt
@@ -0,0 +1 @@
+This is a bird's-eye view of rectangular farmland. 1 airport at the center.
diff --git a/Geo/GeochatP-main/playground/data/prompts/conversation/001_conv.txt b/Geo/GeochatP-main/playground/data/prompts/conversation/001_conv.txt
new file mode 100644
index 0000000000000000000000000000000000000000..950792c8229aab7f908401e00f031e465c2ae06f
--- /dev/null
+++ b/Geo/GeochatP-main/playground/data/prompts/conversation/001_conv.txt
@@ -0,0 +1,8 @@
+Question: Can you describe the layout and arrangement of the farmland in the image? Does it appear to be organized in a specific pattern?
+Answer: The image depicts rectangular farmland, and the farmland appears to be organized in a rectangular pattern.
+
+Question: What does the description reveal about the size and scale of the airport at the center of the farmland?
+Answer: The description mentions the presence of an airport at the center of the farmland, but it does not provide specific details about the size or scale of the airport.
+
+Question: Given that this image depicts rectangular farmland with an airport at the center, how might the presence of an airport in the midst of agricultural land symbolize the intersection of traditional practices and modern connectivity? What opportunities and challenges could arise from this unique juxtaposition in terms of both agricultural productivity and the region's economic development?
+Answer: The airport amidst the farmland symbolizes the convergence of tradition and modernity. It offers opportunities for efficient agricultural exports and economic growth. However, it also raises challenges in preserving farmland. Ultimately, it represents the balance between local agriculture and global connectivity.
diff --git a/Geo/GeochatP-main/playground/data/prompts/conversation/system_message.txt b/Geo/GeochatP-main/playground/data/prompts/conversation/system_message.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d48fdccf4d8d4b38a469ccbf37130b2ee410efb7
--- /dev/null
+++ b/Geo/GeochatP-main/playground/data/prompts/conversation/system_message.txt
@@ -0,0 +1,6 @@
+You are an AI visual assistant, and you are seeing a single image. What you see are provided with sentences, describing the same image you are looking at. The sentences describe various objects present in the scene, their colors, relative sizes as well as relative positions on the image.
+Answer all questions as you are seeing the image. Design a conversation between you and a person asking about this photo. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question. Ask diverse questions and give corresponding answers. Only give definite answers.
+Include questions asking about the visual content of the image, including the object types, counting the objects, object actions, object locations, relative positions between objects, the size of objects, color of objects, etc.
+(1) one can see the content in the image that the question asks about and can answer confidently.
+(2) one can determine confidently from the image that it is not in the image. Do not ask any question that cannot be answered confidently. Also include complex questions that are relevant to the content in the image, for example, asking about background knowledge of the objects in the image, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details.
+Provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary. Do not output anything else other than the question answer pairs.
diff --git a/Geo/GeochatP-main/pyproject.toml b/Geo/GeochatP-main/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..e0559092ad1e83d766bf7657b5cf043f3fa18cb0
--- /dev/null
+++ b/Geo/GeochatP-main/pyproject.toml
@@ -0,0 +1,39 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "geochat"
+version = "1.1.1"
+description = "Grounded VLM for Remote Sensing"
+readme = "README.md"
+requires-python = ">=3.8"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+]
+dependencies = [
+ "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
+ "requests", "sentencepiece", "tokenizers>=0.12.1",
+ "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
+ "shortuuid", "httpx==0.24.0",
+ "deepspeed==0.9.5",
+ "peft==0.4.0",
+ "transformers==4.31.0",
+ "accelerate==0.21.0",
+ "bitsandbytes==0.41.0",
+ "scikit-learn==1.2.2",
+ "sentencepiece==0.1.99",
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+ "gradio_client==0.2.9"
+]
+
+[project.urls]
+"Homepage" = "https://github.com/mbzuai-oryx/GeoChat"
+"Bug Tracker" = "https://github.com/mbzuai-oryx/GeoChat/issues"
+
+[tool.setuptools.packages.find]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+[tool.wheel]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
diff --git a/Geo/GeochatP-main/scripts/extract_mm_projector.py b/Geo/GeochatP-main/scripts/extract_mm_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..45be31e896e9c087093bd9bcb6d355ec6dfd11ab
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/extract_mm_projector.py
@@ -0,0 +1,47 @@
+"""
+This is just a utility that I use to extract the projector for quantized models.
+It is NOT necessary at all to train, or run inference/serve demos.
+Use this script ONLY if you fully understand its implications.
+"""
+
+
+import os
+import argparse
+import torch
+import json
+from collections import defaultdict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Extract MMProjector weights')
+ parser.add_argument('--model-path', type=str, help='model folder')
+ parser.add_argument('--output', type=str, help='output file')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ keys_to_match = ['mm_projector']
+ ckpt_to_key = defaultdict(list)
+ try:
+ model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
+ for k, v in model_indices['weight_map'].items():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+ except FileNotFoundError:
+ # Smaller models or model checkpoints saved by DeepSpeed.
+ v = 'pytorch_model.bin'
+ for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+
+ loaded_weights = {}
+
+ for ckpt_name, weight_keys in ckpt_to_key.items():
+ ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
+ for k in weight_keys:
+ loaded_weights[k] = ckpt[k]
+
+ torch.save(loaded_weights, args.output)
diff --git a/Geo/GeochatP-main/scripts/finetune.sh b/Geo/GeochatP-main/scripts/finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c14f770b481a548c978daca4b42fc0f74aeebe13
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/finetune.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
+
+# Uncomment and set the following variables correspondingly to run this script:
+
+################## VICUNA ##################
+# PROMPT_VERSION=v1
+# MODEL_VERSION="vicuna-v1-3-7b"
+################## VICUNA ##################
+
+################## LLaMA-2 ##################
+# PROMPT_VERSION="llava_llama_2"
+# MODEL_VERSION="llama-2-7b-chat"
+################## LLaMA-2 ##################
+
+deepspeed llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
+ --version $PROMPT_VERSION \
+ --data_path ./playground/data/llava_instruct_80k.json \
+ --image_folder /path/to/coco/train2017 \
+ --vision_tower openai/clip-vit-large-patch14 \
+ --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 50000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/finetune_full_schedule.sh b/Geo/GeochatP-main/scripts/finetune_full_schedule.sh
new file mode 100644
index 0000000000000000000000000000000000000000..59a0d4aa4d8f391c5b5e62452c4e9ef38934b4a9
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/finetune_full_schedule.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
+
+# Uncomment and set the following variables correspondingly to run this script:
+
+################## VICUNA ##################
+# PROMPT_VERSION=v1
+# MODEL_VERSION="vicuna-v1-3-7b"
+################## VICUNA ##################
+
+################## LLaMA-2 ##################
+# PROMPT_VERSION="llava_llama_2"
+# MODEL_VERSION="llama-2-7b-chat"
+################## LLaMA-2 ##################
+
+deepspeed llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
+ --version $PROMPT_VERSION \
+ --data_path ./playground/data/llava_instruct_158k.json \
+ --image_folder /path/to/coco/train2017 \
+ --vision_tower openai/clip-vit-large-patch14 \
+ --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
+ --num_train_epochs 3 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 50000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/finetune_lora.sh b/Geo/GeochatP-main/scripts/finetune_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e987da57d3d9d60c4fe0adc5e61a211086f16c8e
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/finetune_lora.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+################## VICUNA ##################
+PROMPT_VERSION=v1
+MODEL_VERSION="vicuna-v1.5-7b"
+################## VICUNA ##################
+
+ deepspeed --master_port=$((RANDOM + 10000)) --include localhost:gpu_ids geochat/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --lora_enable True \
+ --model_name_or_path path/to/base/llavav1.5-7b \
+ --version $PROMPT_VERSION \
+ --data_path path/to/GeoChat_Instruct.json \
+ --image_folder /share/softwares/kartik/final_images_llava \
+ --vision_tower openai/clip-vit-large-patch14-336 \
+ --mm_projector_type mlp2x_gelu \
+ --pretrain_mm_mlp_adapter path/to/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --image_aspect_ratio pad \
+ --bf16 True \
+ --output_dir path/to/checkpoints_dir \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 32 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "epoch" \
+ --save_steps 7000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-4 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --dataloader_num_workers 16 \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/finetune_qlora.sh b/Geo/GeochatP-main/scripts/finetune_qlora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c2ed4c030cb7a3fff79f47a8e681f4df7c989100
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/finetune_qlora.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
+
+# Uncomment and set the following variables correspondingly to run this script:
+
+################## VICUNA ##################
+# PROMPT_VERSION=v1
+# MODEL_VERSION="vicuna-v1-3-7b"
+################## VICUNA ##################
+
+################## LLaMA-2 ##################
+# PROMPT_VERSION="llava_llama_2"
+# MODEL_VERSION="llama-2-7b-chat"
+################## LLaMA-2 ##################
+
+deepspeed llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --lora_enable True \
+ --bits 4 \
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
+ --version $PROMPT_VERSION \
+ --data_path ./playground/data/llava_instruct_80k.json \
+ --image_folder /path/to/coco/train2017 \
+ --vision_tower openai/clip-vit-large-patch14 \
+ --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 50000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --dataloader_num_workers 4 \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/finetune_sqa.sh b/Geo/GeochatP-main/scripts/finetune_sqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3ed50288c31c118cab22312ad02a559d45725490
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/finetune_sqa.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
+
+deepspeed llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --model_name_or_path lmsys/vicuna-13b-v1.3 \
+ --version $PROMPT_VERSION \
+ --data_path /Data/ScienceQA/data/scienceqa/llava_train_QCM-LEA.json \
+ --image_folder /Data/ScienceQA/data/scienceqa/images/train \
+ --vision_tower openai/clip-vit-large-patch14 \
+ --pretrain_mm_mlp_adapter ./checkpoints/huggingface/liuhaotian/llava-pretrain-vicuna-13b-v1.3/mm_projector.bin \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --output_dir ./checkpoints/llava-vicuna-13b-v1.3-pretrain_lcs558k_plain-ScienceQA_QCM_LEA-12e \
+ --num_train_epochs 12 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 50000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/merge_lora_weights.py b/Geo/GeochatP-main/scripts/merge_lora_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b39cc7beb12301379af7daebbb5553fa92093ea
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/merge_lora_weights.py
@@ -0,0 +1,22 @@
+import argparse
+from geochat.model.builder import load_pretrained_model
+from geochat.mm_utils import get_model_name_from_path
+
+
+def merge_lora(args):
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
+
+ model.save_pretrained(args.save_model_path)
+ tokenizer.save_pretrained(args.save_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, required=True)
+ parser.add_argument("--model-base", type=str, required=True)
+ parser.add_argument("--save-model-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ merge_lora(args)
diff --git a/Geo/GeochatP-main/scripts/pretrain.sh b/Geo/GeochatP-main/scripts/pretrain.sh
new file mode 100644
index 0000000000000000000000000000000000000000..83f263dd570e447b3b009542d26688ce936436af
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/pretrain.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
+
+# Uncomment and set the following variables correspondingly to run this script:
+
+# MODEL_VERSION=vicuna-v1-3-7b
+# MODEL_VERSION=llama-2-7b-chat
+
+########### DO NOT CHANGE ###########
+########### USE THIS FOR BOTH ###########
+PROMPT_VERSION=plain
+########### DO NOT CHANGE ###########
+
+deepspeed llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
+ --version $PROMPT_VERSION \
+ --data_path /path/to/pretrain_data.json \
+ --image_folder /path/to/images \
+ --vision_tower openai/clip-vit-large-patch14 \
+ --tune_mm_mlp_adapter True \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 16 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 24000 \
+ --save_total_limit 1 \
+ --learning_rate 2e-3 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
diff --git a/Geo/GeochatP-main/scripts/zero2.json b/Geo/GeochatP-main/scripts/zero2.json
new file mode 100644
index 0000000000000000000000000000000000000000..c95ebefe07b7d8d9fd0936a014679d07102cc270
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/zero2.json
@@ -0,0 +1,23 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "train_micro_batch_size_per_gpu": "auto",
+ "train_batch_size": "auto",
+ "gradient_accumulation_steps": "auto",
+ "zero_optimization": {
+ "stage": 2,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto"
+ }
+}
\ No newline at end of file
diff --git a/Geo/GeochatP-main/scripts/zero3.json b/Geo/GeochatP-main/scripts/zero3.json
new file mode 100644
index 0000000000000000000000000000000000000000..6917317af62da757ca759a92b326ddfa65b203cc
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/zero3.json
@@ -0,0 +1,28 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "train_micro_batch_size_per_gpu": "auto",
+ "train_batch_size": "auto",
+ "gradient_accumulation_steps": "auto",
+ "zero_optimization": {
+ "stage": 3,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ }
+}
\ No newline at end of file
diff --git a/Geo/GeochatP-main/scripts/zero3_offload.json b/Geo/GeochatP-main/scripts/zero3_offload.json
new file mode 100644
index 0000000000000000000000000000000000000000..e0a54c2c2bc10f76458c42a43de0970a9251759f
--- /dev/null
+++ b/Geo/GeochatP-main/scripts/zero3_offload.json
@@ -0,0 +1,56 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "steps_per_print": 1e5,
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index f4a6a5663124c7ccb3dc2d2d3562d081f515a1c7..0ef84f6fa92841d37de2f35fdaa5850806d6c8ab 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,227 @@
+# GeoChat: Grounded Large Vision-Language Model for Remote Sensing [CVPR-2024]
+
+
+
+
+#### [Kartik Kuckreja](https://www.linkedin.com/in/kartik-kuckreja-930531221/)\*, [Muhammad Sohail Danish](https://www.linkedin.com/in/muhammad-sohail-danish/)\*, [Muzammal Naseer](https://muzammal-naseer.com/), [Abhijit Das](https://sites.google.com/site/dasabhijit2048/home), [Salman Khan](https://salman-h-khan.github.io/) and [Fahad Khan](https://sites.google.com/view/fahadkhans/home)
+\* Equally contributing first authors
+
+#### **Mohamed bin Zayed University of AI, Birla Institute of Technology & Science, Australian National University, Linkoping University**
+
+[Project Page](https://mbzuai-oryx.github.io/GeoChat)
+[Paper](https://arxiv.org/abs/2311.15826)
+[Video](https://youtu.be/KOKtkkKpNDk)
+
---
-title: GeochatP
-emoji: 📉
-colorFrom: purple
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.22.0
-app_file: app.py
-pinned: false
-short_description: GeochatP
+
+## 📢 Latest Updates
+- Supplementary material for the accepted paper is available here: [Supplementary](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/geochat_supp.pdf).
+- **Feb-28-24**: We open source the code, model, dataset, and evaluation scripts.
+- **Feb-27-24**: GeoChat has been accepted to **CVPR-24** 🎉.
+- **Nov-28-23**: GeoChat paper is released [arxiv link](https://arxiv.org/abs/2311.15826). 🔥🔥
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+## Overview
+
+GeoChat is the first grounded Large Vision-Language Model, specifically tailored to Remote Sensing (RS) scenarios. Unlike general-domain models, GeoChat excels in handling high-resolution RS imagery, employing region-level reasoning for comprehensive scene interpretation. Leveraging a newly created RS multimodal dataset, GeoChat is fine-tuned using the LLaVA-1.5 architecture. This results in robust zero-shot performance across various RS tasks, including image and region captioning, visual question answering, scene classification, visually grounded conversations, and referring object detection.
+
+---
+## Contents
+- [Install](#install)
+- [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md)
+- [Dataset](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json)
+- [Train](#train)
+- [Evaluation](#evaluation)
+
+## Install
+
+1. Clone this repository and navigate to GeoChat folder
+```bash
+git clone https://github.com/mbzuai-oryx/GeoChat.git
+cd GeoChat
+```
+
+2. Install Package
+```Shell
+conda create -n geochat python=3.10 -y
+conda activate geochat
+pip install --upgrade pip # enable PEP 660 support
+pip install -e .
+```
+
+3. Install additional packages for training cases
+```
+pip install ninja
+pip install flash-attn --no-build-isolation
+```
+
+### Upgrade to latest code base
+
+```Shell
+git pull
+pip uninstall transformers
+pip install -e .
+```
+
+## GeoChat Weights and Demo
+Please check out our [Model Zoo](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/MODEL_ZOO.md) for all public GeoChat checkpoints, and check [LoRA.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/LoRA.md) for instructions on how to run the demo and training.
+
+## Train
+
+GeoChat training consists of visual instruction tuning on the GeoChat_Instruct dataset (318k Vicuna-generated multimodal instruction-following samples), starting from the pretrained weights of LLaVA-v1.5.
+
+We train GeoChat on 3 A100 GPUs with 40GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
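+
+For example, the following combinations keep the global batch size at 144 (a minimal sketch; only the two batch-size flags are shown, and the GPU counts are illustrative):
+
+```Shell
+# 3 GPUs: 16 per device x 3 accumulation steps x 3 GPUs = 144
+--per_device_train_batch_size 16 --gradient_accumulation_steps 3
+
+# 1 GPU: 16 per device x 9 accumulation steps x 1 GPU = 144
+--per_device_train_batch_size 16 --gradient_accumulation_steps 9
+```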
+
+### Hyperparameters
+We use a similar set of hyperparameters to Vicuna during finetuning. The finetuning hyperparameters are provided below.
+
+| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
+| --- | ---: | ---: | ---: | ---: | ---: |
+| GeoChat-7B | 144 | 2e-5 | 1 | 2048 | 0 |
+
+### Pretrain (feature alignment)
+
+We use the pretrained projector from LLaVA-v1.5, which is trained on a 558K subset of the LAION-CC-SBU dataset with BLIP captions. Pretraining takes around 3.5 hours for LLaVA-v1.5-7B.
+
+- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
+- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
+
+### Visual Instruction Tuning
+
+1. Prepare data
+
+Please download the annotation file of the final mixture of our instruction tuning data, [GeoChat_Instruct.json](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct/blob/main/GeoChat_Instruct.json), and download the split image zips from [Hugging Face](https://huggingface.co/datasets/MBZUAI/GeoChat_Instruct). Save the image zips in a single folder and run the following command to merge them:
+```Shell
+cat images_parta* > images.zip
+```
+Unzip the images.zip file to a folder and give the folder's path in [finetune_lora.sh](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
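+
+A minimal sketch of this step (the folder name and the `--image_folder` argument are assumptions in the style of LLaVA training scripts; check finetune_lora.sh for the exact variable to set):
+
+```Shell
+unzip images.zip -d ./GeoChat_Instruct_images
+# then point finetune_lora.sh at this folder, e.g. --image_folder ./GeoChat_Instruct_images
+```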
+
+2. Start training!
+
+Visual instruction tuning takes more time due to the increased CLIP input resolution of 504x504. It takes around 25 hours to finetune GeoChat-7B on 3x A100 (40GB).
+
+Training script with DeepSpeed ZeRO-3: [`finetune_lora.sh`](https://github.com/mbzuai-oryx/GeoChat/blob/main/scripts/finetune_lora.sh).
+
+Options to note:
+
+- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
+- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
+- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
+- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct).
+
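+A minimal launch sketch is shown below (the entry-point path, the base checkpoint and the data paths are assumptions/placeholders based on the LLaVA-style layout of this repo, and the LoRA-specific flags are omitted; `finetune_lora.sh` remains the authoritative reference):
+
+```Shell
+deepspeed --num_gpus=3 geochat/train/train_mem.py \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.5-7b \
+    --data_path ./GeoChat_Instruct.json \
+    --image_folder ./GeoChat_Instruct_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --image_aspect_ratio pad \
+    --group_by_modality_length True \
+    --per_device_train_batch_size 16 \
+    --gradient_accumulation_steps 3 \
+    --output_dir ./checkpoints/geochat-7b-lora
+```
+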
+## Evaluation
+
+We evaluate GeoChat on a diverse set of 7 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not use beam search, so that the inference process is consistent with the real-time chat demo.
+See [Evaluation.md](https://github.com/mbzuai-oryx/GeoChat/blob/main/docs/Evaluation.md).
+
+## 🏆 Contributions
+
+- **RS multimodal instruction following dataset.** We present a novel data-generation pipeline that leverages existing object detection datasets to create short descriptions of the images, and then uses Vicuna-v1.5 to create conversations from the generated text alone. We further add visual question-answering and scene classification abilities using their corresponding datasets. This results in a total of 318k instruction pairs for the RS domain.
+- **GeoChat.** Leveraging our dataset, we finetune LLaVA-1.5 to create the remote sensing-domain vision-language model GeoChat. Our LoRA fine-tuning is efficient and avoids forgetting the necessary context embedded in the fully-tuned LLaVA model, whose MLP projection is trained to align images into the word embedding space of the LLM (Vicuna-v1.5). This allows GeoChat to retain LLaVA's conversation and instruction-following abilities and extend its domain knowledge to remote sensing tasks.
+
+- **Evaluation Benchmark.** We also address the lack of evaluation benchmarks for assessing the capability of existing VLMs on remote-sensing conversations. To this end, we set up evaluation protocols for conversation grounding in RS, as well as a suite of tasks to allow comparisons with future efforts in this direction. We show various supervised as well as zero-shot evaluations for different remote sensing tasks, including image captioning, visual question answering and scene classification, to demonstrate the generalisability of the GeoChat conversational VLM.
+
+---
+## 👁️💬 GeoChat : Grounded Large Vision-Language Model for Remote Sensing
+
+GeoChat can accomplish multiple tasks for remote-sensing (RS) image comprehension in a unified framework. Given suitable task tokens and user queries, the model can generate visually grounded responses (text with corresponding object locations, shown on top), visual question answering on images and regions (top left and bottom right, respectively), as well as scene classification (top right) and natural language conversations (bottom). This makes it the first RS VLM with grounding capability.
+
+
+
+
+
+---
+
+## 🛰️ GeoChat : Architecture
+
+An overview of GeoChat - the first grounded large vision-language model for remote sensing. Given an image and a user query, a visual backbone first encodes patch-level tokens at a higher resolution by interpolating positional encodings. A multi-layer perceptron (MLP) adapts the vision tokens to a language space suitable for input to a Large Language Model (Vicuna 1.5). Besides visual inputs, region locations can also be given to the model together with task-specific prompts that specify the task desired by the user. Given this context, the LLM generates natural language responses interleaved with corresponding object locations. GeoChat can perform multiple tasks, as shown on top, e.g., scene classification, image/region captioning, VQA and grounded conversations.
+
+
+
+
+
+---
+
+## 🔍 RS Multimodal Instruction Dataset
+
+Types of annotations available in the GeoChat instruction-set. For a given RS image, we obtain object attribute and relationship information, referring expressions and region captions along with their corresponding region annotations (shown over the image). This structured information is used to create the rich instruction-set with a total of 318k image-instruction pairs.
+
+
+
+
+
+
+
+## 🤖 Qualitative results of GeoChat
+
+Qualitative results of GeoChat. (left-right) Results are shown on grounding, referring object detection, and disaster/damage detection. The user can provide task-specific tokens (e.g., [grounding]) to shape model responses according to the desired behavior. The model can generate textual responses (right), only visual grounding (center) and both text and object groundings interleaved together (left). The model can also specify object types, object counts, object attributes and object relationships.
+
+
+
+
+---
+
+## 🤖 Visual Question Answering
+Qualitative examples for Visual Question Answering tasks. GeoChat is able to hold multi-turn conversations based on various types of questions, including presence, count, complex comparisons and so on. It is also able to detect objects and hold conversations about low-resolution images.
+
+
+
+
+---
+
+## 🤖 Scene Classification
+Qualitative examples for scene classification. We give the model all the classes from the dataset and ask it to choose only one.
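+
+An illustrative prompt, taken from the examples shipped with the demo:
+
+```
+Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.
+```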
+
+
+
+
+---
+
+## 🤖 Grounded Description
+When asked to describe the image with the special token '[grounding]', GeoChat outputs both a description of the image and bounding boxes for all the objects it detects.
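+
+An illustrative query using the grounding token (the response then interleaves description text with box tokens such as `{<8><26><22><37>}`; this format is sketched from the demo code, where coordinates are normalized to a 0-100 grid, and may differ slightly from raw model output):
+
+```
+[grounding] Describe the image.
+```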
+
+
+
+
+---
+
+## 🤖 Referring Expression
+When asked about an object via a referring expression, GeoChat is able to locate it and draw a rotated bounding box around it.
+
+
+
+
+
+
+
+---
+
+## 🤖 Region Caption
+Qualitative examples for region-based captioning. Given a bounding box, GeoChat provides a brief description of the area or the object covered by the box.
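+
+An illustrative region query, taken from the examples shipped with the demo (box coordinates are normalized to a 0-100 grid in the demo code):
+
+```
+[identify] what is this {<8><26><22><37>}
+```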
+
+
+
+
+---
+
+## 📜 Citation
+```bibtex
+ @article{kuckreja2023geochat,
+ title={GeoChat: Grounded Large Vision-Language Model for Remote Sensing},
+ author={Kuckreja, Kartik and Danish, Muhammad S. and Naseer, Muzammal and Das, Abhijit and Khan, Salman and Khan, Fahad S.},
+ journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ year={2024}
+ }
+```
+## 🙏 Acknowledgement
+We are thankful to LLaVA and Vicuna for releasing their models and code as open-source contributions.
+
+---
+[IVAL](https://www.ival-mbzuai.com) | [Oryx](https://github.com/mbzuai-oryx) | [MBZUAI](https://mbzuai.ac.ae)
+
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3fcaa009bbd66a363f6d7c926a1c1300e57daac
--- /dev/null
+++ b/app.py
@@ -0,0 +1,35 @@
+import torch
+import gradio as gr
+from torchvision import transforms
+from PIL import Image
+
+# Load model
+class MyModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ # Define layers here
+
+ def forward(self, x):
+ # Forward pass
+ return x
+
+model = MyModel()
+model.load_state_dict(torch.load("model.pth", map_location="cpu"))  # load on CPU so the app also starts on CPU-only hardware
+model.eval()
+
+# Define image preprocessing
+transform = transforms.Compose([
+ transforms.Resize((224, 224)),
+ transforms.ToTensor(),
+])
+
+# Define prediction function
+def predict(image):
+ image = transform(image).unsqueeze(0) # Add batch dimension
+ with torch.no_grad():
+ output = model(image)
+ return output.numpy().tolist()
+
+# Create Gradio interface
+# type="pil" ensures predict() receives a PIL image that the torchvision transforms can handle
+iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")
+iface.launch()
diff --git a/geochat_demo.py b/geochat_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..939ef3444933f8b4e934fcdd5f4e9cc128acccf7
--- /dev/null
+++ b/geochat_demo.py
@@ -0,0 +1,706 @@
+import argparse
+import os
+import random
+from collections import defaultdict
+
+import cv2
+import re
+import math
+import numpy as np
+from PIL import Image
+import torch
+import html
+import gradio as gr
+
+import torchvision.transforms as T
+import torch.backends.cudnn as cudnn
+
+from geochat.conversation import conv_templates, Chat
+from geochat.model.builder import load_pretrained_model
+from geochat.mm_utils import get_model_name_from_path
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ # parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--gpu-id", type=str,default=0)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--max-new-tokens", type=int, default=300)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ # args = parser.parse_args()
+ args = parser.parse_args()
+ return args
+
+
+random.seed(42)
+np.random.seed(42)
+torch.manual_seed(42)
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+print('Initializing Chat')
+args = parse_args()
+# cfg = Config(args)
+
+model_name = get_model_name_from_path(args.model_path)
+tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+device = 'cuda:{}'.format(args.gpu_id)
+
+# model_config = cfg.model_cfg
+# model_config.device_8bit = args.gpu_id
+# model_cls = registry.get_model_class(model_config.arch)
+# model = model_cls.from_config(model_config).to(device)
+bounding_box_size = 100
+
+# vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+# vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+model = model.eval()
+
+CONV_VISION = conv_templates['llava_v1'].copy()
+
+def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
+ # Calculate center coordinates
+ x_ctr = (x1 + x2) / 2
+ y_ctr = (y1 + y2) / 2
+
+ # Calculate width and height
+ w = abs(x2 - x1)
+ h = abs(y2 - y1)
+
+ # Calculate the angle in radians
+ angle_rad = math.radians(a)
+
+ # Calculate coordinates of the four corners of the rotated bounding box
+ cos_a = math.cos(angle_rad)
+ sin_a = math.sin(angle_rad)
+
+ x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
+ y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr
+
+ x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
+ y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr
+
+ x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
+ y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr
+
+ x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
+ y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr
+
+ # Return the polygon coordinates
+ polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))
+
+ return polygon_coords
+
+def rotate_bbox(top_right, bottom_left, angle_degrees):
+ # Convert angle to radians
+ angle_radians = np.radians(angle_degrees)
+
+ # Calculate the center of the rectangle
+ center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)
+
+ # Calculate the width and height of the rectangle
+ width = top_right[0] - bottom_left[0]
+ height = top_right[1] - bottom_left[1]
+
+ # Create a rotation matrix
+ rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)
+
+ # Create an array of the rectangle corners
+ rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
+ [top_right[0], bottom_left[1]],
+ [top_right[0], top_right[1]],
+ [bottom_left[0], top_right[1]]], dtype=np.float32)
+
+ # Rotate the rectangle points
+ rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]
+
+ return rotated_rectangle
+def extract_substrings(string):
+ # first check if there is no-finished bracket
+ index = string.rfind('}')
+ if index != -1:
+ string = string[:index + 1]
+
+ pattern = r'<p>(.*?)\}(?!<)'
+ matches = re.findall(pattern, string)
+ substrings = [match for match in matches]
+
+ return substrings
+
+
+def is_overlapping(rect1, rect2):
+ x1, y1, x2, y2 = rect1
+ x3, y3, x4, y4 = rect2
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+def computeIoU(bbox1, bbox2):
+ x1, y1, x2, y2 = bbox1
+ x3, y3, x4, y4 = bbox2
+ intersection_x1 = max(x1, x3)
+ intersection_y1 = max(y1, y3)
+ intersection_x2 = min(x2, x4)
+ intersection_y2 = min(y2, y4)
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+ union_area = bbox1_area + bbox2_area - intersection_area
+ iou = intersection_area / union_area
+ return iou
+
+
+def save_tmp_img(visual_img):
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+ file_path = "/tmp/gradio" + file_name
+ visual_img.save(file_path)
+ return file_path
+
+
+def mask2bbox(mask):
+ if mask is None:
+ return ''
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
+ mask = np.array(mask)[:, :, 0]
+
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if rows.sum():
+ # Get the top, bottom, left, and right boundaries
+ rmin, rmax = np.where(rows)[0][[0, -1]]
+ cmin, cmax = np.where(cols)[0][[0, -1]]
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+ else:
+ bbox = ''
+
+ return bbox
+
+
+def escape_markdown(text):
+ # List of Markdown special characters that need to be escaped
+ md_chars = ['<', '>']
+
+ # Escape each special character
+ for char in md_chars:
+ text = text.replace(char, '\\' + char)
+
+ return text
+
+
+def reverse_escape(text):
+ md_chars = ['\\<', '\\>']
+
+ for char in md_chars:
+ text = text.replace(char, char[1:])
+
+ return text
+
+
+colors = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (210, 210, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (114, 128, 250),
+ (0, 165, 255),
+ (0, 128, 0),
+ (144, 238, 144),
+ (238, 238, 175),
+ (255, 191, 0),
+ (0, 128, 0),
+ (226, 43, 138),
+ (255, 0, 255),
+ (0, 215, 255),
+]
+
+color_map = {
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+ color_id, color in enumerate(colors)
+}
+
+used_colors = colors
+
+
+def visualize_all_bbox_together(image, generation):
+ if image is None:
+ return None, ''
+
+ generation = html.unescape(generation)
+
+ image_width, image_height = image.size
+ image = image.resize([500, int(500 / image_width * image_height)])
+ image_width, image_height = image.size
+
+ string_list = extract_substrings(generation)
+ if string_list: # it is grounding or detection
+ mode = 'all'
+ entities = defaultdict(list)
+ i = 0
+ j = 0
+ for string in string_list:
+ try:
+ obj, string = string.split('</p>')
+ except ValueError:
+ print('wrong string: ', string)
+ continue
+ if "}{" in string:
+ string = string.replace("}{", "}<delim>{")
+ bbox_list = string.split('<delim>')
+ flag = False
+ for bbox_string in bbox_list:
+ integers = re.findall(r'-?\d+', bbox_string)
+ if len(integers)==4:
+ angle=0
+ else:
+ angle=integers[4]
+ integers=integers[:-1]
+
+ if len(integers) == 4:
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+
+ entities[obj].append([left, bottom, right, top,angle])
+
+ j += 1
+ flag = True
+ if flag:
+ i += 1
+ else:
+ integers = re.findall(r'-?\d+', generation)
+ # if len(integers)==4:
+ angle=0
+ # else:
+ # angle=integers[4]
+ integers=integers[:-1]
+ if len(integers) == 4: # it is refer
+ mode = 'single'
+
+ entities = list()
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+ entities.append([left, bottom, right, top,angle])
+ else:
+ # don't detect any valid bbox to visualize
+ return None, ''
+
+ if len(entities) == 0:
+ return None, ''
+
+ if isinstance(image, Image.Image):
+ image_h = image.height
+ image_w = image.width
+ image = np.array(image)
+
+ elif isinstance(image, str):
+ if os.path.exists(image):
+ pil_img = Image.open(image).convert("RGB")
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ image_h = pil_img.height
+ image_w = pil_img.width
+ else:
+ raise ValueError(f"invaild image path, {image}")
+ elif isinstance(image, torch.Tensor):
+
+ image_tensor = image.cpu()
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+ pil_img = T.ToPILImage()(image_tensor)
+ image_h = pil_img.height
+ image_w = pil_img.width
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ else:
+ raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+ indices = list(range(len(entities)))
+
+ new_image = image.copy()
+
+ previous_bboxes = []
+ # size of text
+ text_size = 0.4
+ # thickness of text
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
+ box_line = 2
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+ base_height = int(text_height * 0.675)
+ text_offset_original = text_height - base_height
+ text_spaces = 2
+
+ # num_bboxes = sum(len(x[-1]) for x in entities)
+ used_colors = colors # random.sample(colors, k=num_bboxes)
+
+ color_id = -1
+ for entity_idx, entity_name in enumerate(entities):
+ if mode == 'single' or mode == 'identify':
+ bboxes = entity_name
+ bboxes = [bboxes]
+ else:
+ bboxes = entities[entity_name]
+ color_id += 1
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm,angle) in enumerate(bboxes):
+ skip_flag = False
+ orig_x1, orig_y1, orig_x2, orig_y2,angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)
+
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
+ top_right=(orig_x1,orig_y1)
+ bottom_left=(orig_x2,orig_y2)
+ angle=angle
+ rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
+ new_image=cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True,thickness=2, color=color)
+
+ # new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+ if mode == 'all':
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+ x1 = orig_x1 - l_o
+ y1 = orig_y1 - l_o
+
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+ x1 = orig_x1 + r_o
+
+ # add text background
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+ text_line)
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+ for prev_bbox in previous_bboxes:
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+ prev_bbox['phrase'] == entity_name:
+ skip_flag = True
+ break
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+ if text_bg_y2 >= image_h:
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+ text_bg_y2 = image_h
+ y1 = image_h
+ break
+ if not skip_flag:
+ alpha = 0.5
+ for i in range(text_bg_y1, text_bg_y2):
+ for j in range(text_bg_x1, text_bg_x2):
+ if i < image_h and j < image_w:
+ if j < text_bg_x1 + 1.35 * c_width:
+ # original color
+ bg_color = color
+ else:
+ # white
+ bg_color = [255, 255, 255]
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+ np.uint8)
+
+ cv2.putText(
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+ )
+
+ previous_bboxes.append(
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+ if mode == 'all':
+ def color_iterator(colors):
+ while True:
+ for color in colors:
+ yield color
+
+ color_gen = color_iterator(colors)
+
+ # Add colors to phrases and remove
+ def colored_phrases(match):
+ phrase = match.group(1)
+ color = next(color_gen)
+ return f'<span style="color:rgb{color}">{phrase}</span>'
+
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
+ generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
+ else:
+ generation_colored = ''
+
+ pil_image = Image.fromarray(new_image)
+ return pil_image, generation_colored
+
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
+ interactive=True), chat_state, img_list
+
+
+def image_upload_trigger(upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list:
+ replace_flag = 1
+ return upload_flag, replace_flag
+
+
+def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list or replace_flag == 1:
+ replace_flag = 1
+
+ return upload_flag, replace_flag
+
+
+def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+ if len(user_message) == 0:
+ text_box_show = 'Input should not be empty!'
+ else:
+ text_box_show = ''
+
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+ else:
+ mask = None
+
+ if '[identify]' in user_message:
+ # check if user provide bbox in the text input
+ integers = re.findall(r'-?\d+', user_message)
+ if len(integers) != 4: # no bbox in text
+ bbox = mask2bbox(mask)
+ user_message = user_message + bbox
+
+ if chat_state is None:
+ chat_state = CONV_VISION.copy()
+
+ if upload_flag:
+ if replace_flag:
+ chat_state = CONV_VISION.copy() # new image, reset everything
+ replace_flag = 0
+ chatbot = []
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ upload_flag = 0
+
+ chat.ask(user_message, chat_state)
+
+ chatbot = chatbot + [[user_message, None]]
+
+ if '[identify]' in user_message:
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+ if visual_img is not None:
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[(file_path,), None]]
+
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
+# def gradio_answer(chatbot, chat_state, img_list, temperature):
+# llm_message = chat.answer(conv=chat_state,
+# img_list=img_list,
+# temperature=temperature,
+# max_new_tokens=500,
+# max_length=2000)[0]
+# chatbot[-1][1] = llm_message
+# return chatbot, chat_state
+
+
+def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+ if len(img_list) > 0:
+ if not isinstance(img_list[0], torch.Tensor):
+ chat.encode_img(img_list)
+ streamer = chat.stream_answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)
+ # chatbot[-1][1] = output
+ # chat_state.messages[-1][1] = ''
+
+ output = ''
+ for new_output in streamer:
+ # print(new_output)
+ output=output+new_output
+ print(output)
+ # if "{" in output:
+ # chatbot[-1][1]="Grounding and referring expression is still under work."
+ # else:
+ output = escape_markdown(output)
+ # output += escapped
+ chatbot[-1][1] = output
+ yield chatbot, chat_state
+ chat_state.messages[-1][1] = ''
+ return chatbot, chat_state
+
+
+def gradio_visualize(chatbot, gr_img):
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+
+ unescaped = reverse_escape(chatbot[-1][1])
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+ if visual_img is not None:
+ if len(generation_color):
+ chatbot[-1][1] = generation_color
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[None, (file_path,)]]
+
+ return chatbot
+
+
+def gradio_taskselect(idx):
+ prompt_list = [
+ '',
+ 'Classify the image in the following classes: ',
+ '[identify] what is this ',
+ ]
+ instruct_list = [
+ '**Hint:** Type in whatever you want',
+ '**Hint:** Type in the classes you want the model to classify in',
+ '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
+ ]
+ return prompt_list[idx], instruct_list[idx]
+
+
+
+
+chat = Chat(model, image_processor,tokenizer, device=device)
+
+
+title = """GeoChat Demo
"""
+description = 'Welcome to Our GeoChat Chatbot Demo!'
+article = """"""
+# article = """
"""
+
+introduction = '''
+1. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
+2. No Tag: Input whatever you want and CLICK **Send** without any tagging
+
+You can also simply chat in free form!
+'''
+
+
+text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
+ scale=12)
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ # gr.Markdown(description)
+ gr.Markdown(article)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=1.5,
+ value=0.6,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ clear = gr.Button("Restart")
+
+ gr.Markdown(introduction)
+
+ with gr.Column():
+ chat_state = gr.State(value=None)
+ img_list = gr.State(value=[])
+ chatbot = gr.Chatbot(label='GeoChat')
+
+ dataset = gr.Dataset(
+ components=[gr.Textbox(visible=False)],
+ samples=[['No Tag'], ['Scene Classification'],['Identify']],
+ type="index",
+ label='Task Shortcuts',
+ )
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
+ with gr.Row():
+ text_input.render()
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+ upload_flag = gr.State(value=0)
+ replace_flag = gr.State(value=0)
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+
+ with gr.Row():
+ with gr.Column():
+ gr.Examples(examples=[
+ ["demo_images/train_2956_0001.png", "Where are the airplanes located and what is their type?", upload_flag, replace_flag,
+ img_list],
+ ["demo_images/7292.JPG", "How many buildings are flooded?", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+ with gr.Column():
+ gr.Examples(examples=[
+ ["demo_images/church_183.png", "Classify the image in the following classes: Church, Beach, Dense Residential, Storage Tanks.",
+ upload_flag, replace_flag, img_list],
+ ["demo_images/04444.png", "[identify] what is this {<8><26><22><37>}", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+
+ dataset.click(
+ gradio_taskselect,
+ inputs=[dataset],
+ outputs=[text_input, task_inst],
+ show_progress="hidden",
+ postprocess=False,
+ queue=False,
+ )
+
+ text_input.submit(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ send.click(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
+
+demo.launch(share=True, enable_queue=True,server_name='0.0.0.0')
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..e0559092ad1e83d766bf7657b5cf043f3fa18cb0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,39 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "geochat"
+version = "1.1.1"
+description = "Grounded VLM for Remote Sensing"
+readme = "README.md"
+requires-python = ">=3.8"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+]
+dependencies = [
+ "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
+ "requests", "sentencepiece", "tokenizers>=0.12.1",
+ "torch==2.0.1", "torchvision==0.15.2", "uvicorn", "wandb",
+ "shortuuid", "httpx==0.24.0",
+ "deepspeed==0.9.5",
+ "peft==0.4.0",
+ "transformers==4.31.0",
+ "accelerate==0.21.0",
+ "bitsandbytes==0.41.0",
+ "scikit-learn==1.2.2",
+ "sentencepiece==0.1.99",
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
+ "gradio_client==0.2.9"
+]
+
+[project.urls]
+"Homepage" = "https://github.com/mbzuai-oryx/GeoChat"
+"Bug Tracker" = "https://github.com/mbzuai-oryx/GeoChat/issues"
+
+[tool.setuptools.packages.find]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+[tool.wheel]
+exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]