huangjy-pku committed
Commit 7978a78 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.glb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ logs/
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: LEO
+ emoji: 🦁
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.10.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,181 @@
+ import os
+
+ import gradio as gr
+
+ from utils import *
+
+
+ with gr.Blocks(title='LEO Demo') as demo:
+     gr.HTML(value="<h1 align='center'>An Embodied Generalist Agent in 3D World</h1>")
+     gr.HTML(value="<div align='center' style='margin-top:-1em; margin-bottom:-1em;'><img src='/file=assets/leo.svg' width='4%'></div>")
+     # gr.HTML(value="<img src='/file=assets/teaser.png' alt='Teaser' width='760px' style='display: block; margin: auto;'>")
+     gr.HTML(value="<p align='center' style='font-size: 1.2em; color: #485fc7;'><a href='https://arxiv.org/abs/2311.12871' target='_blank'>arXiv</a> | <a href='https://embodied-generalist.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/embodied-generalist/embodied-generalist' target='_blank'>Code</a></p>")
+     gr.HTML(value="<p align='center' style='font-size: 1.15em;'><i>LEO: an embodied generalist agent capable of perceiving, grounding, reasoning, planning, and acting in 3D world.</i></p>")
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             dropdown_scene = gr.Dropdown(
+                 choices=MESH_NAMES,
+                 value=MESH_NAMES[0],
+                 interactive=True,
+                 label='Select a 3D scene',
+             )
+             model_3d = gr.Model3D(
+                 value=os.path.join(MESH_DIR, f'{MESH_NAMES[0]}.glb'),
+                 clear_color=[0.0, 0.0, 0.0, 0.0],
+                 label='3D Scene',
+                 camera_position=(90, 30, 10),
+                 height=659,
+             )
+             gr.HTML(
+                 """<center><strong>
+                 👆 SCROLL and DRAG on the 3D Scene
+                 to zoom in/out and rotate. Press CTRL and DRAG to pan.
+                 </strong></center>
+                 """
+             )
+         with gr.Column(scale=5):
+             dropdown_conversation_mode = gr.Dropdown(
+                 choices=['Single-round mode', 'Multi-round mode'],
+                 value='Single-round mode',
+                 interactive=True,
+                 label='Select conversation mode',
+             )
+             chatbot = gr.Chatbot(label='Chat with LEO')
+             with gr.Row():
+                 with gr.Column(scale=8):
+                     user_chat_input = gr.Textbox(
+                         placeholder="Enter text here to chat with LEO",
+                         show_label=False,
+                         autofocus=True,
+                     )
+                 with gr.Column(scale=2, min_width=0):
+                     send_button = gr.Button('Send', variant='primary', scale=2)
+             with gr.Row():
+                 upvote_button = gr.Button(value='👍 Upvote', interactive=False)
+                 downvote_button = gr.Button(value='👎 Downvote', interactive=False)
+                 flag_button = gr.Button(value='⚠️ Flag', interactive=False)
+                 clear_button = gr.Button(value='🗑️ Clear', interactive=False)
+             with gr.Row():
+                 with gr.Accordion(label="Examples for user instruction:", open=True):
+                     gr.Examples(
+                         examples=[
+                             ["How many armchairs are there in this room?"],
+                             ["Is there a radio in the room?"],
+                             ["Where is the wardrobe located?TODO"],
+                             ["What is the shape of the shelf in front of the picture?TODO"],
+                             ["Plan for the task: Tidy up and arrange the nursery room.TODO"],
+                         ],
+                         inputs=user_chat_input,
+                     )
+
+             # generation_config
+             with gr.Accordion('Parameters', open=False):
+                 repetition_penalty = gr.Slider(
+                     minimum=0.0,
+                     maximum=10.0,
+                     value=3.0,
+                     step=1.0,
+                     interactive=True,
+                     label='Repetition penalty',
+                 )
+                 length_penalty = gr.Slider(
+                     minimum=0.0,
+                     maximum=10.0,
+                     value=1.0,
+                     step=1.0,
+                     interactive=True,
+                     label="Length penalty",
+                 )
+             gr.Markdown("### Terms of Service")
+             gr.HTML(
+                 """By using this service, users are required to agree to the following terms:
+                 the service is a research preview intended for non-commercial use only
+                 and may collect user dialogue data for future research."""
+             )
+             gr.Markdown("### Acknowledgment")
+             gr.HTML(
+                 """Template adapted from <a href="https://llava.hliu.cc/">LLaVA</a> and
+                 <a href="http://sled-whistler.eecs.umich.edu:7777/">LLM-Grounder</a>."""
+             )
+
+     # Event handling
+     button_list = [upvote_button, downvote_button, flag_button, clear_button]
+
+     dropdown_scene.change(
+         fn=change_scene,
+         inputs=[dropdown_scene],
+         outputs=[model_3d, chatbot],
+         queue=False,
+     )
+
+     dropdown_conversation_mode.change(
+         fn=clear_history,
+         inputs=[],
+         outputs=[chatbot, user_chat_input] + button_list,
+         queue=False,
+     )
+
+     user_chat_input.submit(
+         fn=receive_instruction,
+         inputs=[chatbot, user_chat_input],
+         outputs=[chatbot, user_chat_input, send_button] + button_list,
+         queue=False,
+     ).then(
+         fn=generate_response,
+         inputs=[
+             chatbot,
+             dropdown_scene,
+             dropdown_conversation_mode,
+             repetition_penalty,
+             length_penalty,
+         ],
+         outputs=[chatbot, send_button] + button_list,
+         scroll_to_output=True,
+     )
+
+     send_button.click(
+         fn=receive_instruction,
+         inputs=[chatbot, user_chat_input],
+         outputs=[chatbot, user_chat_input, send_button] + button_list,
+         queue=False,
+     ).then(
+         fn=generate_response,
+         inputs=[
+             chatbot,
+             dropdown_scene,
+             dropdown_conversation_mode,
+             repetition_penalty,
+             length_penalty,
+         ],
+         outputs=[chatbot, send_button] + button_list,
+         scroll_to_output=True,
+     )
+
+     upvote_button.click(
+         upvote_response,
+         [chatbot, dropdown_scene, dropdown_conversation_mode],
+         [user_chat_input, upvote_button, downvote_button, flag_button],
+         queue=False,
+     )
+     downvote_button.click(
+         downvote_response,
+         [chatbot, dropdown_scene, dropdown_conversation_mode],
+         [user_chat_input, upvote_button, downvote_button, flag_button],
+         queue=False,
+     )
+     flag_button.click(
+         flag_response,
+         [chatbot, dropdown_scene, dropdown_conversation_mode],
+         [user_chat_input, upvote_button, downvote_button, flag_button],
+         queue=False,
+     )
+     clear_button.click(
+         fn=clear_history,
+         inputs=[],
+         outputs=[chatbot, user_chat_input] + button_list,
+         queue=False,
+     )
+
+
+ demo.queue().launch(share=True, allowed_paths=['assets'])
assets/leo.svg ADDED
assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5642bb84ba04d10c5aa199dbcd5ea1ab01df0d2517719a2a2e943381f11bd25b
+ size 1002083
assets/obj_features/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eae2324173b34331b6dad37c89a75db275d1d23fbb1f1d7478573085cdf1d733
+ size 1002083
assets/obj_features/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50a9e124ea270cbe23b59fbddb527d5cf61005c657bd3f5f41535998ba84d9b6
+ size 1002083
assets/scene_meshes/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d197483b3be1f6f1395faa3a8b413ee23335fd8f081456b63db96f5928291b1
+ size 9632176
assets/scene_meshes/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:419988aa4781ec7d0a06e9087c8a918a20c389c50b210daa6b3c47be981b28ac
+ size 9445868
assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0db74afa2648056c839840ba8a11d832012b6f70114668835c2da82d5ae07ec2
+ size 11326324
model/cfg.yaml ADDED
@@ -0,0 +1,21 @@
+ use_ckpt: hf
+ hf_ckpt_path: [huangjy-pku/embodied-generalist, weights/leo_noact_hf.pth]
+ local_ckpt_path: /mnt/huangjiangyong/leo/hf_assets/weights/leo_noact_lora.pth
+ model:
+   name: LeoAgentLLM
+   # vision modules omitted
+   llm:
+     name: Vicuna7B
+     use_ckpt: hf
+     hf_cfg_path: huangjy-pku/vicuna-7b
+     local_cfg_path: /mnt/huangjiangyong/vicuna-7b
+     truncation_side: right
+     prompt: ""
+     max_out_len: 256
+     max_context_len: 256  # for prompt_after_obj
+     lora:
+       flag: True
+       rank: 16
+       alpha: 16
+       target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
+       dropout: 0.0
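For orientation, this is how the config above is consumed at startup (mirroring `load_agent` in `utils.py` below): the YAML is parsed into an `OmegaConf` object, `use_ckpt: hf` selects the Hugging Face Hub checkpoint, and the two-element `hf_ckpt_path` list is the (repo_id, filename) pair passed to `hf_hub_download`. A minimal sketch, assuming the file layout of this repo:

```python
import yaml
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

# Parse model/cfg.yaml into an attribute-style config (same as utils.load_agent).
with open('model/cfg.yaml') as f:
    cfg = OmegaConf.create(yaml.safe_load(f))

# `use_ckpt: hf` means the agent weights come from the Hub;
# hf_ckpt_path is [repo_id, filename_in_repo].
if cfg.use_ckpt == 'hf':
    ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
else:
    ckpt_path = cfg.local_ckpt_path  # local fallback path

print(ckpt_path)                # resolved checkpoint location
print(cfg.model.llm.lora.rank)  # nested fields are attribute-accessible, e.g. 16
```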
model/leo_agent.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import snapshot_download
+ from peft import get_peft_model, LoraConfig
+ from transformers import LlamaForCausalLM, LlamaTokenizer
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ class LeoAgentLLM(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         if hasattr(cfg, 'model'):
+             cfg = cfg.model
+
+         # LLM
+         if cfg.llm.use_ckpt == 'hf':
+             llm_cfg_path = snapshot_download(cfg.llm.hf_cfg_path)
+         else:
+             llm_cfg_path = cfg.llm.local_cfg_path
+         self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False,
+                                                             truncation_side=cfg.llm.truncation_side)
+         self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+         self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
+         self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+         self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+         self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
+         self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+         for param in self.llm_model.parameters():
+             param.requires_grad = False
+         self.llm_model.eval()
+         self.llm_model.train = disabled_train
+
+         # LoRA-based LLM fine-tuning
+         if cfg.llm.lora.flag:
+             lora_config = LoraConfig(
+                 r=cfg.llm.lora.rank,
+                 lora_alpha=cfg.llm.lora.alpha,
+                 target_modules=cfg.llm.lora.target_modules,
+                 lora_dropout=cfg.llm.lora.dropout,
+                 bias='none',
+                 modules_to_save=[],
+             )
+             self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)
+
+         self.max_context_len = cfg.llm.max_context_len
+
+     @property
+     def device(self):
+         return list(self.parameters())[0].device
+
+     def build_right_justified_sequence(self, data_dict):
+         """
+         Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`.
+         Return right justified sequence for causal LM: <pad>, <role/situation>, <img>, <objs>, <instruction>.
+         """
+         bs = len(data_dict['prompt_before_obj'])
+
+         self.llm_tokenizer.padding_side = 'left'
+         text_input_tokens_pre = self.llm_tokenizer(
+             data_dict['prompt_before_obj'],
+             return_tensors='pt',
+             padding='longest'
+         ).to(self.device)  # [PAD, BOS, tokens], (B, T1)
+
+         text_input_tokens_mid1 = self.llm_tokenizer(
+             data_dict['prompt_middle_1'],
+             return_tensors='pt',
+             padding='longest'
+         ).to(self.device)
+
+         img_tokens = data_dict['img_tokens'].to(self.device)
+         img_masks = data_dict['img_masks'].to(self.device)
+         img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))
+
+         text_input_tokens_mid2 = self.llm_tokenizer(
+             data_dict['prompt_middle_2'],
+             return_tensors='pt',
+             padding='longest'
+         ).to(self.device)
+
+         obj_tokens = data_dict['obj_tokens'].to(self.device)
+         obj_masks = data_dict['obj_masks'].to(self.device)
+
+         self.llm_tokenizer.padding_side = 'right'  # no need to be 'left', as padding tokens will be shifted
+         self.llm_tokenizer.truncation_side = 'left'  # truncate history
+         text_input_tokens_post = self.llm_tokenizer(
+             data_dict['prompt_after_obj'],
+             return_tensors='pt',
+             padding='longest',
+             truncation=True,
+             max_length=self.max_context_len,
+         ).to(self.device)  # [BOS, tokens, PAD], (B, T3)
+
+         # hardcode, remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
+         assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \
+             "prompt_middle should be the same and thus no padding"
+
+         text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:]
+         text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:]
+         for i in range(bs):
+             if not img_masks[i].any():
+                 # no image input, also mask the text prompt for image tokens
+                 text_input_tokens_mid1.attention_mask[i].fill_(0)
+
+         text_input_tokens_mid2.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)
+         text_input_tokens_post.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)
+
+         inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids)
+         inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids)
+         inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids)
+         inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids)
+
+         # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
+         inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1)
+         attn_mask_mid = torch.cat([
+             text_input_tokens_mid1.attention_mask, img_masks,
+             text_input_tokens_mid2.attention_mask, obj_masks
+         ], dim=1)
+
+         post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1)
+
+         bs, l1, hidden_dim = inputs_embeds_pre.shape
+         _, l2, _ = inputs_embeds_mid.shape
+         _, l3, _ = inputs_embeds_post.shape
+
+         inputs_embeds = torch.zeros(
+             bs, l1+l2+l3, hidden_dim
+         ).type(inputs_embeds_pre.dtype).to(self.device)
+
+         attention_mask = torch.zeros(
+             bs, l1+l2+l3
+         ).type(obj_masks.dtype).to(self.device)
+
+         # assign by chunks
+         for i in range(bs):
+             post_pad_len = post_pad_length[i]
+
+             if post_pad_len > 0:
+                 inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:]
+                 attention_mask[i, :post_pad_len] = 0
+                 inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len]
+                 attention_mask[i, post_pad_len+l1+l2:] = 1
+             else:
+                 # no padding
+                 inputs_embeds[i, -l3:] = inputs_embeds_post[i]
+                 attention_mask[i, -l3:] = 1
+
+             inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i]
+             attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i]
+
+             inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i]
+             attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i]
+
+         return inputs_embeds, attention_mask
+
+     @torch.no_grad()
+     def generate(
+         self,
+         data_dict,
+         use_nucleus_sampling=False,
+         num_beams=5,
+         max_length=256,
+         min_length=1,
+         repetition_penalty=3.0,
+         length_penalty=1,
+         num_captions=1,
+         temperature=1,
+     ):
+         assert 'img_tokens' in data_dict and 'obj_tokens' in data_dict, "Visual features should have been processed offline."
+
+         inputs_embeds, attention_mask = self.build_right_justified_sequence(data_dict=data_dict)
+         bs = inputs_embeds.shape[0]
+
+         # give bos token as condition
+         bos_tokens = self.llm_tokenizer(
+             [self.llm_tokenizer.bos_token] * bs,
+             return_tensors='pt',
+         ).to(self.device)
+         bos_tokens_ids = bos_tokens.input_ids[:, 0:1]  # (B, 1)
+         bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]  # (B, 1)
+
+         # prepare a `bos_token`
+         bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)  # (B, 1, D)
+         inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)  # (B, T1+O+T2+1, D)
+         attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)  # (B, T1+O+T2+1)
+
+         outputs = self.llm_model.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             do_sample=use_nucleus_sampling,
+             temperature=temperature,
+             num_beams=num_beams,
+             max_length=max_length,
+             min_length=min_length,
+             repetition_penalty=repetition_penalty,
+             length_penalty=length_penalty,
+             num_return_sequences=num_captions,
+         )
+
+         outputs[outputs == 0] = 2  # convert output id 0 (unk_token) to 2 (eos_token)
+
+         output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+         output_text = [text.strip() for text in output_text]
+         return output_text
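For clarity, here is a sketch of the minimal `data_dict` that `LeoAgentLLM.generate` expects, based on the assertions above and on how `generate_response` in `utils.py` builds it. The 4096-dim feature size matches the Vicuna-7B hidden size used in the demo; the exact object-token shape `(1, N, 4096)` and the number of object slots are assumptions about the precomputed `.pth` features.

```python
import torch

# Hypothetical illustration of the batch-size-1 input consumed by LeoAgentLLM.generate.
# Text fields are lists of strings (one per sample); visual fields are padded tensors.
num_objs = 60  # assumed number of object proposals in the precomputed scene features
data_dict = {
    'prompt_before_obj': ["You are an AI visual assistant situated in a 3D scene. ..."],
    'prompt_middle_1': ["Ego-view image:"],
    'prompt_middle_2': ["Objects (including you) in the scene:"],
    'prompt_after_obj': ["USER: How many armchairs are there in this room? ASSISTANT:"],
    'img_tokens': torch.zeros(1, 1, 4096).float(),         # placeholder ego-view feature
    'img_masks': torch.zeros(1, 1).bool(),                  # all-zero mask => no image input
    'obj_tokens': torch.zeros(1, num_objs, 4096).float(),   # object features (assets/obj_features/*.pth)
    'obj_masks': torch.ones(1, num_objs).bool(),            # which object slots are valid
}

# agent = LeoAgentLLM(cfg); agent.generate(data_dict) would then return a list of strings.
```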
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ omegaconf==2.3.0
+ peft==0.5.0
+ pyyaml==6.0.1
+ sentencepiece
+ torch==1.13.0+cu116
+ transformers==4.28.1
utils.py ADDED
@@ -0,0 +1,185 @@
+
+ import datetime
+ import json
+ import os
+ import time
+
+ import gradio as gr
+ import torch
+ import yaml
+ from huggingface_hub import hf_hub_download
+ from omegaconf import OmegaConf
+
+ from model.leo_agent import LeoAgentLLM
+
+ LOG_DIR = 'logs'
+ os.makedirs(LOG_DIR, exist_ok=True)  # logs/ is gitignored, so create it on startup
+ MESH_DIR = 'assets/scene_meshes'
+ MESH_NAMES = [os.path.splitext(fname)[0] for fname in os.listdir(MESH_DIR)]
+ ENABLE_BUTTON = gr.update(interactive=True)
+ DISABLE_BUTTON = gr.update(interactive=False)
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ ROLE_PROMPT = "You are an AI visual assistant situated in a 3D scene. "\
+               "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
+               "You should properly respond to the USER's instruction according to the given visual information. "
+ EGOVIEW_PROMPT = "Ego-view image:"
+ OBJECTS_PROMPT = "Objects (including you) in the scene:"
+ TASK_PROMPT = "USER: {instruction} ASSISTANT:"
+ OBJ_FEATS_DIR = 'assets/obj_features'
+
+
+ def load_agent():
+     # build model
+     with open('model/cfg.yaml') as f:
+         cfg = yaml.safe_load(f)
+     cfg = OmegaConf.create(cfg)
+     agent = LeoAgentLLM(cfg)
+
+     # load checkpoint
+     if cfg.use_ckpt == 'hf':
+         ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
+     else:
+         ckpt_path = cfg.local_ckpt_path
+     ckpt = torch.load(ckpt_path, map_location='cpu')
+     agent.load_state_dict(ckpt, strict=False)
+
+     agent.eval()
+     agent.to(DEVICE)
+     return agent
+
+ agent = load_agent()
+
+
+ def get_log_fname():
+     t = datetime.datetime.now()
+     fname = os.path.join(LOG_DIR, f'{t.year}-{t.month:02d}-{t.day:02d}.json')
+     return fname
+
+
+ def change_scene(dropdown_scene: str):
+     # reset 3D scene and chatbot history
+     return os.path.join(MESH_DIR, f'{dropdown_scene}.glb'), None
+
+
+ def receive_instruction(chatbot: gr.Chatbot, user_chat_input: gr.Textbox):
+     # display user input, after submitting user message, before inference
+     chatbot.append((user_chat_input, None))
+     return (chatbot, gr.update(value=""),) + (DISABLE_BUTTON,) * 5
+
+
+ def generate_response(
+     chatbot: gr.Chatbot,
+     dropdown_scene: gr.Dropdown,
+     dropdown_conversation_mode: gr.Dropdown,
+     repetition_penalty: float, length_penalty: float
+ ):
+     # response starts
+     chatbot[-1] = (chatbot[-1][0], "▌")
+     yield (chatbot,) + (DISABLE_BUTTON,) * 5
+
+     # create data_dict, batch_size = 1
+     data_dict = {
+         'prompt_before_obj': [ROLE_PROMPT],
+         'prompt_middle_1': [EGOVIEW_PROMPT],
+         'prompt_middle_2': [OBJECTS_PROMPT],
+         'img_tokens': torch.zeros(1, 1, 4096).float(),
+         'img_masks': torch.zeros(1, 1).bool(),
+         'anchor_locs': torch.zeros(1, 3).float(),
+     }
+
+     # initialize prompt
+     prompt = ""
+     if 'Multi-round' in dropdown_conversation_mode:
+         # multi-round dialogue, with memory
+         for (q, a) in chatbot[:-1]:
+             prompt += f"USER: {q.strip()} ASSISTANT: {a.strip()}</s>"
+
+     prompt += f"USER: {chatbot[-1][0]} ASSISTANT:"
+     data_dict['prompt_after_obj'] = [prompt]
+
+     # anchor orientation
+     anchor_orient = torch.zeros(1, 4).float()
+     anchor_orient[:, -1] = 1
+     data_dict['anchor_orientation'] = anchor_orient
+
+     # load preprocessed scene features
+     data_dict.update(torch.load(os.path.join(OBJ_FEATS_DIR, f'{dropdown_scene}.pth'), map_location='cpu'))
+
+     # inference
+     for k, v in data_dict.items():
+         if isinstance(v, torch.Tensor):
+             data_dict[k] = v.to(DEVICE)
+
+     output = agent.generate(
+         data_dict,
+         repetition_penalty=float(repetition_penalty),
+         length_penalty=float(length_penalty),
+     )
+     output = output[0]
+
+     # display response
+     for out_len in range(1, len(output)-1):
+         chatbot[-1] = (chatbot[-1][0], output[:out_len] + '▌')
+         yield (chatbot,) + (DISABLE_BUTTON,) * 5
+         time.sleep(0.01)
+
+     chatbot[-1] = (chatbot[-1][0], output)
+     vote_response(chatbot, 'log', dropdown_scene, dropdown_conversation_mode)
+     yield (chatbot,) + (ENABLE_BUTTON,) * 5
+
+
+ def vote_response(
+     chatbot: gr.Chatbot, vote_type: str,
+     dropdown_scene: gr.Dropdown,
+     dropdown_conversation_mode: gr.Dropdown
+ ):
+     t = datetime.datetime.now()
+     this_log = {
+         'time': f'{t.hour:02d}:{t.minute:02d}:{t.second:02d}',
+         'type': vote_type,
+         'scene': dropdown_scene,
+         'mode': dropdown_conversation_mode,
+         'dialogue': chatbot,
+     }
+     fname = get_log_fname()
+     if os.path.exists(fname):
+         with open(fname) as f:
+             logs = json.load(f)
+         logs.append(this_log)
+     else:
+         logs = [this_log]
+     with open(fname, 'w') as f:
+         json.dump(logs, f, indent=2)
+
+
+ def upvote_response(
+     chatbot: gr.Chatbot,
+     dropdown_scene: gr.Dropdown,
+     dropdown_conversation_mode: gr.Dropdown
+ ):
+     vote_response(chatbot, 'upvote', dropdown_scene, dropdown_conversation_mode)
+     return ("",) + (DISABLE_BUTTON,) * 3
+
+
+ def downvote_response(
+     chatbot: gr.Chatbot,
+     dropdown_scene: gr.Dropdown,
+     dropdown_conversation_mode: gr.Dropdown
+ ):
+     vote_response(chatbot, 'downvote', dropdown_scene, dropdown_conversation_mode)
+     return ("",) + (DISABLE_BUTTON,) * 3
+
+
+ def flag_response(
+     chatbot: gr.Chatbot,
+     dropdown_scene: gr.Dropdown,
+     dropdown_conversation_mode: gr.Dropdown
+ ):
+     vote_response(chatbot, 'flag', dropdown_scene, dropdown_conversation_mode)
+     return ("",) + (DISABLE_BUTTON,) * 3
+
+
+ def clear_history():
+     # reset chatbot history
+     return (None, "",) + (DISABLE_BUTTON,) * 4
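`generate_response` assembles the instruction prompt (`prompt_after_obj`) as alternating USER/ASSISTANT turns, closing each finished assistant turn with `</s>` in multi-round mode, while the role, ego-view, and object prompts are prepended separately inside the model. A small worked example of the resulting string, using a hypothetical chat history:

```python
# Hypothetical chat history as stored in the gradio Chatbot: (user, assistant) pairs,
# with the last pair still awaiting a response.
chatbot = [
    ("Is there a radio in the room?", "No, there is no radio in the room."),
    ("How many armchairs are there in this room?", None),
]

# Same assembly as generate_response (multi-round mode).
prompt = ""
for q, a in chatbot[:-1]:
    prompt += f"USER: {q.strip()} ASSISTANT: {a.strip()}</s>"
prompt += f"USER: {chatbot[-1][0]} ASSISTANT:"

print(prompt)
# USER: Is there a radio in the room? ASSISTANT: No, there is no radio in the room.</s>USER: How many armchairs are there in this room? ASSISTANT:
```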