Commit 7978a78 · Parent: init
Files changed:
- .gitattributes +36 -0
- .gitignore +2 -0
- README.md +13 -0
- app.py +181 -0
- assets/leo.svg +0 -0
- assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth +3 -0
- assets/obj_features/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.pth +3 -0
- assets/obj_features/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.pth +3 -0
- assets/scene_meshes/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.glb +3 -0
- assets/scene_meshes/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.glb +3 -0
- assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb +3 -0
- model/cfg.yaml +21 -0
- model/leo_agent.py +210 -0
- requirements.txt +7 -0
- utils.py +184 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
```text
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.glb filter=lfs diff=lfs merge=lfs -text
```
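Every binary asset in this commit (the `.pth` features and `.glb` meshes) matches one of the patterns above, so Git stores only small LFS pointer texts; those pointers appear verbatim in the asset sections below. As a quick sanity check, a sketch like the following (hypothetical helper, not part of the commit) distinguishes a raw pointer from the smudged binary:

```python
# Hypothetical helper, not part of this commit: checks whether a file on disk
# is still a raw Git LFS pointer (e.g. the repo was cloned without
# `git lfs install`) instead of the binary it stands for.
def is_lfs_pointer(path: str) -> bool:
    with open(path, 'rb') as f:
        head = f.read(64)
    return head.startswith(b'version https://git-lfs.github.com/spec/v1')

print(is_lfs_pointer('assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb'))
```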
.gitignore ADDED
@@ -0,0 +1,2 @@
```text
__pycache__
logs/
```
README.md ADDED
@@ -0,0 +1,13 @@
```markdown
---
title: LEO
emoji: 🦁
colorFrom: purple
colorTo: blue
sdk: gradio
sdk_version: 4.10.0
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```
app.py ADDED
@@ -0,0 +1,181 @@
```python
import os

import gradio as gr

from utils import *


with gr.Blocks(title='LEO Demo') as demo:
    gr.HTML(value="<h1 align='center'>An Embodied Generalist Agent in 3D World</h1>")
    gr.HTML(value="<div align='center' style='margin-top:-1em; margin-bottom:-1em;'><img src='/file=assets/leo.svg' width='4%'></div>")
    # gr.HTML(value="<img src='/file=assets/teaser.png' alt='Teaser' width='760px' style='display: block; margin: auto;'>")
    gr.HTML(value="<p align='center' style='font-size: 1.2em; color: #485fc7;'><a href='https://arxiv.org/abs/2311.12871' target='_blank'>arXiv</a> | <a href='https://embodied-generalist.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/embodied-generalist/embodied-generalist' target='_blank'>Code</a></p>")
    gr.HTML(value="<p align='center' style='font-size: 1.15em;'><i>LEO: an embodied generalist agent capable of perceiving, grounding, reasoning, planning, and acting in 3D world.</i></p>")

    with gr.Row():
        with gr.Column(scale=5):
            dropdown_scene = gr.Dropdown(
                choices=MESH_NAMES,
                value=MESH_NAMES[0],
                interactive=True,
                label='Select a 3D scene',
            )
            model_3d = gr.Model3D(
                value=os.path.join(MESH_DIR, f'{MESH_NAMES[0]}.glb'),
                clear_color=[0.0, 0.0, 0.0, 0.0],
                label='3D Scene',
                camera_position=(90, 30, 10),
                height=659,
            )
            gr.HTML(
                """<center><strong>
                👆 SCROLL and DRAG on the 3D Scene
                to zoom in/out and rotate. Press CTRL and DRAG to pan.
                </strong></center>
                """
            )
        with gr.Column(scale=5):
            dropdown_conversation_mode = gr.Dropdown(
                choices=['Single-round mode', 'Multi-round mode'],
                value='Single-round mode',
                interactive=True,
                label='Select conversation mode',
            )
            chatbot = gr.Chatbot(label='Chat with LEO')
            with gr.Row():
                with gr.Column(scale=8):
                    user_chat_input = gr.Textbox(
                        placeholder="Enter text here to chat with LEO",
                        show_label=False,
                        autofocus=True,
                    )
                with gr.Column(scale=2, min_width=0):
                    send_button = gr.Button('Send', variant='primary', scale=2)
            with gr.Row():
                upvote_button = gr.Button(value='👍 Upvote', interactive=False)
                downvote_button = gr.Button(value='👎 Downvote', interactive=False)
                flag_button = gr.Button(value='⚠️ Flag', interactive=False)
                clear_button = gr.Button(value='🗑️ Clear', interactive=False)
            with gr.Row():
                with gr.Accordion(label="Examples for user instruction:", open=True):
                    gr.Examples(
                        examples=[
                            ["How many armchairs are there in this room?"],
                            ["Is there a radio in the room?"],
["Where is the wardrobe located?TODO"],
|
66 |
+
["What is the shape of the shelf in front of the picture?TODO"],
|
67 |
+
["Plan for the task: Tidy up and arrange the nursery room.TODO"],
                        ],
                        inputs=user_chat_input,
                    )

            # generation_config
            with gr.Accordion('Parameters', open=False):
                repetition_penalty = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    value=3.0,
                    step=1.0,
                    interactive=True,
                    label='Repetition penalty',
                )
                length_penalty = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    value=1.0,
                    step=1.0,
                    interactive=True,
                    label="Length penalty",
                )
    gr.Markdown("### Terms of Service")
    gr.HTML(
        """By using this service, users are required to agree to the following terms:
        the service is a research preview intended for non-commercial use only
        and may collect user dialogue data for future research."""
    )
    gr.Markdown("### Acknowledgment")
    gr.HTML(
        """Template adapted from <a href="https://llava.hliu.cc/">LLaVA</a> and
        <a href="http://sled-whistler.eecs.umich.edu:7777/">LLM-Grounder</a>."""
    )

    # Event handling
    button_list = [upvote_button, downvote_button, flag_button, clear_button]

    dropdown_scene.change(
        fn=change_scene,
        inputs=[dropdown_scene],
        outputs=[model_3d, chatbot],
        queue=False,
    )

    dropdown_conversation_mode.change(
        fn=clear_history,
        inputs=[],
        outputs=[chatbot, user_chat_input] + button_list,
        queue=False,
    )

    user_chat_input.submit(
        fn=receive_instruction,
        inputs=[chatbot, user_chat_input],
        outputs=[chatbot, user_chat_input, send_button] + button_list,
        queue=False,
    ).then(
        fn=generate_response,
        inputs=[
            chatbot,
            dropdown_scene,
            dropdown_conversation_mode,
            repetition_penalty,
            length_penalty,
        ],
        outputs=[chatbot, send_button] + button_list,
        scroll_to_output=True,
    )

    send_button.click(
        fn=receive_instruction,
        inputs=[chatbot, user_chat_input],
        outputs=[chatbot, user_chat_input, send_button] + button_list,
        queue=False,
    ).then(
        fn=generate_response,
        inputs=[
            chatbot,
            dropdown_scene,
            dropdown_conversation_mode,
            repetition_penalty,
            length_penalty,
        ],
        outputs=[chatbot, send_button] + button_list,
        scroll_to_output=True,
    )

    upvote_button.click(
        upvote_response,
        [chatbot, dropdown_scene, dropdown_conversation_mode],
        [user_chat_input, upvote_button, downvote_button, flag_button],
        queue=False,
    )
    downvote_button.click(
        downvote_response,
        [chatbot, dropdown_scene, dropdown_conversation_mode],
        [user_chat_input, upvote_button, downvote_button, flag_button],
        queue=False,
    )
    flag_button.click(
        flag_response,
        [chatbot, dropdown_scene, dropdown_conversation_mode],
        [user_chat_input, upvote_button, downvote_button, flag_button],
        queue=False,
    )
    clear_button.click(
        fn=clear_history,
        inputs=[],
        outputs=[chatbot, user_chat_input] + button_list,
        queue=False,
    )


demo.queue().launch(share=True, allowed_paths=['assets'])
```
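Both the Enter-key submit and the Send button run the same two-step chain: `receive_instruction` echoes the message and locks the UI, then `generate_response` streams the reply token by token. The streaming works because `generate_response` (defined in utils.py below) is a generator and the app is launched with `demo.queue()`; `allowed_paths=['assets']` is what lets the inline `<img src='/file=assets/leo.svg'>` resolve. A minimal sketch of that streaming pattern, with a hypothetical handler standing in for `generate_response`:

```python
import gradio as gr

# Hypothetical stand-in for generate_response: yielding successively longer
# chat histories makes the Chatbot re-render after every step, producing the
# typing effect. Generator handlers require the app to run with demo.queue().
def stream_reply(history):
    reply = "hello from LEO"
    for i in range(1, len(reply) + 1):
        history[-1] = (history[-1][0], reply[:i])
        yield history
```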
assets/leo.svg ADDED
assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:5642bb84ba04d10c5aa199dbcd5ea1ab01df0d2517719a2a2e943381f11bd25b
size 1002083
```

assets/obj_features/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.pth ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:eae2324173b34331b6dad37c89a75db275d1d23fbb1f1d7478573085cdf1d733
size 1002083
```

assets/obj_features/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.pth ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:50a9e124ea270cbe23b59fbddb527d5cf61005c657bd3f5f41535998ba84d9b6
size 1002083
```
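These `.pth` files hold the offline-extracted 3D object features for each scene: utils.py merges their contents straight into `data_dict` via `torch.load`, and model/leo_agent.py then reads `obj_tokens` and `obj_masks` from that dict. A sketch for inspecting one; the exact keys and shapes are an assumption inferred from that usage, not verified:

```python
import torch

feats = torch.load(
    'assets/obj_features/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.pth',
    map_location='cpu',
)
for k, v in feats.items():
    print(k, tuple(v.shape) if isinstance(v, torch.Tensor) else type(v))
# Inferred from leo_agent.py's usage: at least 'obj_tokens' of shape
# (1, num_objects, hidden_dim) and 'obj_masks' of shape (1, num_objects).
```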
assets/scene_meshes/3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158.glb ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:6d197483b3be1f6f1395faa3a8b413ee23335fd8f081456b63db96f5928291b1
size 9632176
```

assets/scene_meshes/3RScan-0cac760d-8d6f-2d13-8ea2-109ce4da9ac9.glb ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:419988aa4781ec7d0a06e9087c8a918a20c389c50b210daa6b3c47be981b28ac
size 9445868
```

assets/scene_meshes/3RScan-752cc597-920c-26f5-8c1b-a8a5c90a21d7.glb ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:0db74afa2648056c839840ba8a11d832012b6f70114668835c2da82d5ae07ec2
size 11326324
```
model/cfg.yaml ADDED
@@ -0,0 +1,21 @@
```yaml
use_ckpt: hf
hf_ckpt_path: [huangjy-pku/embodied-generalist, weights/leo_noact_hf.pth]
local_ckpt_path: /mnt/huangjiangyong/leo/hf_assets/weights/leo_noact_lora.pth
model:
  name: LeoAgentLLM
  # vision modules omitted
  llm:
    name: Vicuna7B
    use_ckpt: hf
    hf_cfg_path: huangjy-pku/vicuna-7b
    local_cfg_path: /mnt/huangjiangyong/vicuna-7b
    truncation_side: right
    prompt: ""
    max_out_len: 256
    max_context_len: 256  # for prompt_after_obj
    lora:
      flag: True
      rank: 16
      alpha: 16
      target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
      dropout: 0.0
```
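The two-element `hf_ckpt_path` list is consumed as a `(repo_id, filename)` pair. Condensed from `load_agent` in utils.py below, checkpoint resolution amounts to:

```python
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

cfg = OmegaConf.load('model/cfg.yaml')
if cfg.use_ckpt == 'hf':
    # hf_ckpt_path is a [repo_id, filename] pair on the Hugging Face Hub
    ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
else:
    # local fallback, only meaningful on the author's own machine
    ckpt_path = cfg.local_ckpt_path
```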
model/leo_agent.py ADDED
@@ -0,0 +1,210 @@
```python
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from peft import get_peft_model, LoraConfig
from transformers import LlamaForCausalLM, LlamaTokenizer


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class LeoAgentLLM(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        if hasattr(cfg, 'model'):
            cfg = cfg.model

        # LLM
        if cfg.llm.use_ckpt == 'hf':
            llm_cfg_path = snapshot_download(cfg.llm.hf_cfg_path)
        else:
            llm_cfg_path = cfg.llm.local_cfg_path
        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False,
                                                            truncation_side=cfg.llm.truncation_side)
        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
        self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))

        for param in self.llm_model.parameters():
            param.requires_grad = False
        self.llm_model.eval()
        self.llm_model.train = disabled_train

        # LoRA-based LLM fine-tuning
        if cfg.llm.lora.flag:
            lora_config = LoraConfig(
                r=cfg.llm.lora.rank,
                lora_alpha=cfg.llm.lora.alpha,
                target_modules=cfg.llm.lora.target_modules,
                lora_dropout=cfg.llm.lora.dropout,
                bias='none',
                modules_to_save=[],
            )
            self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)

        self.max_context_len = cfg.llm.max_context_len

    @property
    def device(self):
        return list(self.parameters())[0].device

    def build_right_justified_sequence(self, data_dict):
        """
        Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`.
        Return right justified sequence for causal LM: <pad>, <role/situation>, <img>, <objs>, <instruction>.
        """
        bs = len(data_dict['prompt_before_obj'])

        self.llm_tokenizer.padding_side = 'left'
        text_input_tokens_pre = self.llm_tokenizer(
            data_dict['prompt_before_obj'],
            return_tensors='pt',
            padding='longest'
        ).to(self.device)  # [PAD, BOS, tokens], (B, T1)

        text_input_tokens_mid1 = self.llm_tokenizer(
            data_dict['prompt_middle_1'],
            return_tensors='pt',
            padding='longest'
        ).to(self.device)

        img_tokens = data_dict['img_tokens'].to(self.device)
        img_masks = data_dict['img_masks'].to(self.device)
        img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))

        text_input_tokens_mid2 = self.llm_tokenizer(
            data_dict['prompt_middle_2'],
            return_tensors='pt',
            padding='longest'
        ).to(self.device)

        obj_tokens = data_dict['obj_tokens'].to(self.device)
        obj_masks = data_dict['obj_masks'].to(self.device)

        self.llm_tokenizer.padding_side = 'right'  # no need to be 'left', as padding tokens will be shifted
        self.llm_tokenizer.truncation_side = 'left'  # truncate history
        text_input_tokens_post = self.llm_tokenizer(
            data_dict['prompt_after_obj'],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_context_len,
        ).to(self.device)  # [BOS, tokens, PAD], (B, T3)

        # hardcode, remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
        assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \
            "prompt_middle should be the same and thus no padding"

        text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:]
        text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:]
        for i in range(bs):
            if not img_masks[i].any():
                # no image input, also mask the text prompt for image tokens
                text_input_tokens_mid1.attention_mask[i].fill_(0)

        text_input_tokens_mid2.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)
        text_input_tokens_post.input_ids[:, 0] = 869  # 1 (bos) -> 869 (▁.)

        inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids)
        inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids)
        inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids)
        inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids)

        # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
        inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1)
        attn_mask_mid = torch.cat([
            text_input_tokens_mid1.attention_mask, img_masks,
            text_input_tokens_mid2.attention_mask, obj_masks
        ], dim=1)

        post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1)

        bs, l1, hidden_dim = inputs_embeds_pre.shape
        _, l2, _ = inputs_embeds_mid.shape
        _, l3, _ = inputs_embeds_post.shape

        inputs_embeds = torch.zeros(
            bs, l1+l2+l3, hidden_dim
        ).type(inputs_embeds_pre.dtype).to(self.device)

        attention_mask = torch.zeros(
            bs, l1+l2+l3
        ).type(obj_masks.dtype).to(self.device)

        # assign by chunks
        for i in range(bs):
            post_pad_len = post_pad_length[i]

            if post_pad_len > 0:
                inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:]
                attention_mask[i, :post_pad_len] = 0
                inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len]
                attention_mask[i, post_pad_len+l1+l2:] = 1
            else:
                # no padding
                inputs_embeds[i, -l3:] = inputs_embeds_post[i]
                attention_mask[i, -l3:] = 1

            inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i]
            attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i]

            inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i]
            attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i]

        return inputs_embeds, attention_mask

    @torch.no_grad()
    def generate(
        self,
        data_dict,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=256,
        min_length=1,
        repetition_penalty=3.0,
        length_penalty=1,
        num_captions=1,
        temperature=1,
    ):
        assert 'img_tokens' in data_dict and 'obj_tokens' in data_dict, "Visual features should have been processed offline."

        inputs_embeds, attention_mask = self.build_right_justified_sequence(data_dict=data_dict)
        bs = inputs_embeds.shape[0]

        # give bos token as condition
        bos_tokens = self.llm_tokenizer(
            [self.llm_tokenizer.bos_token] * bs,
            return_tensors='pt',
        ).to(self.device)
        bos_tokens_ids = bos_tokens.input_ids[:, 0:1]  # (B, 1)
        bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]  # (B, 1)

        # prepare a `bos_token`
        bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)  # (B, 1, D)
        inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)  # (B, T1+O+T2+1, D)
        attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)  # (B, T1+O+T2+1)

        outputs = self.llm_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=use_nucleus_sampling,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

        outputs[outputs == 0] = 2  # convert output id 0 (unk_token) to 2 (eos_token)

        output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        return output_text
```
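The right-justified packing in `build_right_justified_sequence` keeps every real token flush against the right edge, so the last position of each row is valid context for autoregressive generation even when prompt lengths differ within a batch. A toy illustration of the same idea on token ids (hypothetical values; the model code above does this on embeddings, not ids):

```python
import torch

# Two rows of different lengths, right-justified into one padded batch.
rows = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
T = max(len(r) for r in rows)
ids = torch.zeros(len(rows), T, dtype=torch.long)   # 0 stands for padding
mask = torch.zeros(len(rows), T, dtype=torch.long)
for i, r in enumerate(rows):
    ids[i, T - len(r):] = r    # real tokens pushed to the right edge
    mask[i, T - len(r):] = 1   # leftover positions stay masked out
print(ids)   # tensor([[1, 2, 3], [0, 4, 5]])
print(mask)  # tensor([[1, 1, 1], [0, 1, 1]])
```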
requirements.txt ADDED
@@ -0,0 +1,7 @@
```text
--extra-index-url https://download.pytorch.org/whl/cu116
omegaconf==2.3.0
peft==0.5.0
pyyaml==6.0.1
sentencepiece
torch==1.13.0+cu116
transformers==4.28.1
```
utils.py ADDED
@@ -0,0 +1,184 @@
```python
import datetime
import json
import os
import time

import gradio as gr
import torch
import yaml
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from model.leo_agent import LeoAgentLLM

LOG_DIR = 'logs'
MESH_DIR = 'assets/scene_meshes'
MESH_NAMES = [os.path.splitext(fname)[0] for fname in os.listdir(MESH_DIR)]
ENABLE_BUTTON = gr.update(interactive=True)
DISABLE_BUTTON = gr.update(interactive=False)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

ROLE_PROMPT = "You are an AI visual assistant situated in a 3D scene. "\
              "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
              "You should properly respond to the USER's instruction according to the given visual information. "
EGOVIEW_PROMPT = "Ego-view image:"
OBJECTS_PROMPT = "Objects (including you) in the scene:"
TASK_PROMPT = "USER: {instruction} ASSISTANT:"
OBJ_FEATS_DIR = 'assets/obj_features'


def load_agent():
    # build model
    with open('model/cfg.yaml') as f:
        cfg = yaml.safe_load(f)
    cfg = OmegaConf.create(cfg)
    agent = LeoAgentLLM(cfg)

    # load checkpoint
    if cfg.use_ckpt == 'hf':
        ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
    else:
        ckpt_path = cfg.local_ckpt_path
    ckpt = torch.load(ckpt_path, map_location='cpu')
    agent.load_state_dict(ckpt, strict=False)

    agent.eval()
    agent.to(DEVICE)
    return agent

agent = load_agent()


def get_log_fname():
    t = datetime.datetime.now()
    fname = os.path.join(LOG_DIR, f'{t.year}-{t.month:02d}-{t.day:02d}.json')
    return fname


def change_scene(dropdown_scene: str):
    # reset 3D scene and chatbot history
    return os.path.join(MESH_DIR, f'{dropdown_scene}.glb'), None


def receive_instruction(chatbot: gr.Chatbot, user_chat_input: gr.Textbox):
    # display user input, after submitting user message, before inference
    chatbot.append((user_chat_input, None))
    return (chatbot, gr.update(value=""),) + (DISABLE_BUTTON,) * 5


def generate_response(
    chatbot: gr.Chatbot,
    dropdown_scene: gr.Dropdown,
    dropdown_conversation_mode: gr.Dropdown,
    repetition_penalty: float, length_penalty: float
):
    # response starts
    chatbot[-1] = (chatbot[-1][0], "▌")
    yield (chatbot,) + (DISABLE_BUTTON,) * 5

    # create data_dict, batch_size = 1
    data_dict = {
        'prompt_before_obj': [ROLE_PROMPT],
        'prompt_middle_1': [EGOVIEW_PROMPT],
        'prompt_middle_2': [OBJECTS_PROMPT],
        'img_tokens': torch.zeros(1, 1, 4096).float(),
        'img_masks': torch.zeros(1, 1).bool(),
        'anchor_locs': torch.zeros(1, 3).float(),
    }

    # initialize prompt
    prompt = ""
    if 'Multi-round' in dropdown_conversation_mode:
        # multi-round dialogue, with memory
        for (q, a) in chatbot[:-1]:
            prompt += f"USER: {q.strip()} ASSISTANT: {a.strip()}</s>"

    prompt += f"USER: {chatbot[-1][0]} ASSISTANT:"
    data_dict['prompt_after_obj'] = [prompt]

    # anchor orientation
    anchor_orient = torch.zeros(1, 4).float()
    anchor_orient[:, -1] = 1
    data_dict['anchor_orientation'] = anchor_orient

    # load preprocessed scene features
    data_dict.update(torch.load(os.path.join(OBJ_FEATS_DIR, f'{dropdown_scene}.pth'), map_location='cpu'))

    # inference
    for k, v in data_dict.items():
        if isinstance(v, torch.Tensor):
            data_dict[k] = v.to(DEVICE)

    output = agent.generate(
        data_dict,
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
    )
    output = output[0]

    # display response
    for out_len in range(1, len(output)-1):
        chatbot[-1] = (chatbot[-1][0], output[:out_len] + '▌')
        yield (chatbot,) + (DISABLE_BUTTON,) * 5
        time.sleep(0.01)

    chatbot[-1] = (chatbot[-1][0], output)
    vote_response(chatbot, 'log', dropdown_scene, dropdown_conversation_mode)
    yield (chatbot,) + (ENABLE_BUTTON,) * 5


def vote_response(
    chatbot: gr.Chatbot, vote_type: str,
    dropdown_scene: gr.Dropdown,
    dropdown_conversation_mode: gr.Dropdown
):
    t = datetime.datetime.now()
    this_log = {
        'time': f'{t.hour:02d}:{t.minute:02d}:{t.second:02d}',
        'type': vote_type,
        'scene': dropdown_scene,
        'mode': dropdown_conversation_mode,
        'dialogue': chatbot,
    }
    # make sure the log directory exists (logs/ is gitignored, so it may be absent)
    os.makedirs(LOG_DIR, exist_ok=True)
    fname = get_log_fname()
    if os.path.exists(fname):
        with open(fname) as f:
            logs = json.load(f)
        logs.append(this_log)
    else:
        logs = [this_log]
    with open(fname, 'w') as f:
        json.dump(logs, f, indent=2)


def upvote_response(
    chatbot: gr.Chatbot,
    dropdown_scene: gr.Dropdown,
    dropdown_conversation_mode: gr.Dropdown
):
    vote_response(chatbot, 'upvote', dropdown_scene, dropdown_conversation_mode)
    return ("",) + (DISABLE_BUTTON,) * 3


def downvote_response(
    chatbot: gr.Chatbot,
    dropdown_scene: gr.Dropdown,
    dropdown_conversation_mode: gr.Dropdown
):
    vote_response(chatbot, 'downvote', dropdown_scene, dropdown_conversation_mode)
    return ("",) + (DISABLE_BUTTON,) * 3


def flag_response(
    chatbot: gr.Chatbot,
    dropdown_scene: gr.Dropdown,
    dropdown_conversation_mode: gr.Dropdown
):
    vote_response(chatbot, 'flag', dropdown_scene, dropdown_conversation_mode)
    return ("",) + (DISABLE_BUTTON,) * 3


def clear_history():
    # reset chatbot history
    return (None, "",) + (DISABLE_BUTTON,) * 4
```
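Put together, a single-round query can be reproduced without the UI. This condensed, non-streaming sketch mirrors what `generate_response` assembles; the scene name and instruction are just examples, and note that importing utils runs `load_agent()`, which downloads and loads the full checkpoint:

```python
import os
import torch
from utils import (agent, DEVICE, OBJ_FEATS_DIR, ROLE_PROMPT,
                   EGOVIEW_PROMPT, OBJECTS_PROMPT)

scene = '3RScan-0cac759b-8d6f-2d13-8e3b-2e3bc1ee1158'
data_dict = {
    'prompt_before_obj': [ROLE_PROMPT],
    'prompt_middle_1': [EGOVIEW_PROMPT],
    'prompt_middle_2': [OBJECTS_PROMPT],
    'prompt_after_obj': ["USER: Is there a radio in the room? ASSISTANT:"],
    'img_tokens': torch.zeros(1, 1, 4096).float(),  # no ego-view image
    'img_masks': torch.zeros(1, 1).bool(),
    'anchor_locs': torch.zeros(1, 3).float(),       # anchor pose, as in generate_response
}
anchor_orient = torch.zeros(1, 4).float()
anchor_orient[:, -1] = 1
data_dict['anchor_orientation'] = anchor_orient

# merge the offline object features, then move tensors to the agent's device
data_dict.update(torch.load(os.path.join(OBJ_FEATS_DIR, f'{scene}.pth'), map_location='cpu'))
data_dict = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
             for k, v in data_dict.items()}
print(agent.generate(data_dict)[0])
```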