|
--- |
|
license: mit |
|
license_link: https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE |
|
|
|
language: |
|
- multilingual |
|
pipeline_tag: text-generation |
|
tags: |
|
- nlp |
|
- code |
|
- vision |
|
widget: |
|
- messages: |
|
- role: user |
|
content: <|image_1|>\nWhat action should the robot take to {lang}? |
|
--- |
|
|
|
## TraceVLA-Phi3V |
|
The ``TraceVLA-Phi3V`` model is a vision-language-action model obtained by finetuning the base OpenVLA-Phi3V model on the Open X-Embodiment robot mixture dataset with the [visual trace prompting](https://arxiv.org/pdf/2412.10345) technique.
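
At a high level, visual trace prompting overlays the recent 2D trajectories of tracked points (the robot end effector and moving objects) onto the current observation before it is passed to the policy. The snippet below is a minimal illustrative sketch of that overlay step; the `tracks` array and the drawing style are hypothetical placeholders and do not reproduce the exact procedure of the released `TraceProcessor` (see the sample inference code below for the actual API).

```python
# Illustrative sketch only: overlay hypothetical point tracks on an observation.
import numpy as np
from PIL import Image, ImageDraw

def overlay_traces(image: Image.Image, tracks: np.ndarray) -> Image.Image:
    """Draw each tracked point's recent 2D trajectory on top of the image.

    tracks: hypothetical array of shape (num_points, num_steps, 2) holding
    pixel coordinates of tracked points over the last few frames.
    """
    image = image.copy()
    draw = ImageDraw.Draw(image)
    for point_track in tracks:
        coords = [(int(x), int(y)) for x, y in point_track]
        draw.line(coords, fill=(255, 0, 0), width=2)                    # trajectory
        x, y = coords[-1]
        draw.ellipse([x - 3, y - 3, x + 3, y + 3], fill=(255, 255, 0))  # current position
    return image

# Example with two random 5-step tracks on a blank 256x256 observation.
obs = Image.new("RGB", (256, 256))
tracks = np.random.randint(0, 256, size=(2, 5, 2))
obs_with_traces = overlay_traces(obs, tracks)
```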
|
|
|
### Results on SimplerEnv (Fractal + Bridge):
|
|
|
#### Fractal: |
|
| Policy/Settings | Pick up Coke | Move near | Open/Close Drawer | Put in Drawer | Average |
|:------:|:------------:|:---------:|:------------:|:-----------:|:-------:|
| (Visual Matching) OpenVLA-Phi3V | 56.7% | 53.3% | **38.4%** | **15.7%** | 41.0% |
| (Visual Matching) TraceVLA-Phi3V | **69.7%** | **70.8%** | 35.4% | 0.0% | **44.0%** |
| (Variant Aggregation) OpenVLA-Phi3V | 55.4% | 57.7% | 19.3% | **10.6%** | 35.8% |
| (Variant Aggregation) TraceVLA-Phi3V | **75.4%** | **67.8%** | **37.5%** | 0.0% | **45.1%** |
|
|
|
#### Bridge: |
|
| Policy/Settings | Put Spoon | Put Carrot | Stack Block | Put Eggplant | Average |
|:------:|:------------:|:---------:|:------------:|:-----------:|:-------:|
| OpenVLA-Phi3V | **12.5%** | 0.0% | 0.0% | 8.3% | 5.2% |
| TraceVLA-Phi3V | 8.3% | 0.0% | **12.5%** | **66.7%** | **21.9%** |
|
|
|
|
|
### Sample Inference Code |
|
Here is sample inference code for TraceVLA-Phi3V.
|
```python
# Load Processor & VLA
import json
import os

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# model_path: local path or Hugging Face Hub id of the TraceVLA-Phi3V checkpoint
processor = AutoProcessor.from_pretrained(
    model_path, trust_remote_code=True, num_crops=1
)
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_path, |
|
torch_dtype=torch.bfloat16, |
|
trust_remote_code=True, |
|
_attn_implementation='flash_attention_2', |
|
use_cache=False |
|
).cuda() |
|
|
|
# Load Visual Trace Processor (Co-Tracker based)
from prismatic.eval import TraceProcessor

# cotracker_model_path: local path to the pretrained Co-Tracker checkpoint
trace_processor = TraceProcessor(cotracker_model_path)
|
|
|
# Load dataset statistics used to un-normalize the predicted actions
dataset_stats_path = os.path.join(model_path, 'dataset_statistics.json')
with open(dataset_stats_path, 'r') as file:
    # dataset_name: key of the robot dataset whose action statistics to use
    action_norm_stats = json.load(file)[dataset_name]['action']
model.prepare_action_inference(action_norm_stats, processor.tokenizer.vocab_size)
|
|
|
task_description: str = None  # Task language instruction
### IMPORTANT: Make sure the input image is of size (336, 336)
image: Image.Image = None  # Image observation

# Get visual-trace-overlaid image observation
# (resize_image: simple image-resizing helper, e.g. a wrapper around PIL's Image.resize)
image = resize_image(image, (256, 256))  ### 256x256 is the Co-Tracker input resolution
image_overlaid, has_trace = trace_processor.process_image(image)
image_overlaid = resize_image(image_overlaid, (336, 336))  ### 336x336 is the resolution of the Phi3V image encoder
|
|
|
# Prepare TraceVLA prompt format |
|
if not has_trace: |
|
prompt_message = { |
|
'role': 'user', |
|
'content': f'<|image_1|><|image_2|>\nWhat action should the robot take to {task_description}?', |
|
} |
|
else: |
|
prompt_message = { |
|
'role': 'user', |
|
'content': f'You are given two images: one with the original robot observation <|image_1|>, and another one marked with historical traces of the robot end effector and moving objects <|image_2|>.\nWhat action should the robot take to {task_description}?',
|
} |
|
prompt = processor.tokenizer.apply_chat_template( |
|
[prompt_message], tokenize=False, add_generation_prompt=True |
|
) |
|
inputs = processor(prompt, [image, image_overlaid]).to("cuda:0", dtype=torch.bfloat16) |
|
|
|
|
|
# Get the action output from the model
action = model.predict_action(**inputs)
|
``` |
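
In practice, the steps above run inside a control loop: at every step the newest observation goes through the trace processor, the prompt is rebuilt depending on whether a trace is available yet, and the predicted action is executed on the robot. The sketch below shows one way to wire this up; `env`, `get_observation`, `execute`, and `max_steps` are hypothetical stand-ins for your own robot or simulator interface and are not part of the TraceVLA release.

```python
# Hedged sketch of a closed-loop rollout. `env`, `get_observation`, `execute`,
# and `max_steps` are hypothetical stand-ins for your robot/simulator interface.
for step in range(max_steps):
    image = env.get_observation()                               # latest RGB observation
    image = resize_image(image, (256, 256))                     # Co-Tracker input resolution
    image_overlaid, has_trace = trace_processor.process_image(image)
    image_overlaid = resize_image(image_overlaid, (336, 336))   # Phi3V encoder resolution

    # Rebuild `prompt_message` from `has_trace` exactly as shown above, then:
    prompt = processor.tokenizer.apply_chat_template(
        [prompt_message], tokenize=False, add_generation_prompt=True
    )
    inputs = processor(prompt, [image, image_overlaid]).to("cuda:0", dtype=torch.bfloat16)

    action = model.predict_action(**inputs)                     # predicted robot action
    env.execute(action)                                         # hypothetical: send action to the robot
```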
|
|
|
For more examples, including scripts for finetuning OpenVLA-Phi3V models on your own robot demonstration datasets, check out our [repository](https://github.com/FrankZheng2022/tracevla/tree/phi3). |
|
|
|
|
|
|
|
|
|
### Citation |
|
|
|
If you find our code or models useful in your work, please cite [our paper](https://arxiv.org/abs/2412.10345): |
|
|
|
```bibtex |
|
@misc{zheng2024tracevlavisualtraceprompting, |
|
title={TraceVLA: Visual Trace Prompting Enhances Spatial-Temporal Awareness for Generalist Robotic Policies}, |
|
author={Ruijie Zheng and Yongyuan Liang and Shuaiyi Huang and Jianfeng Gao and Hal Daumé III and Andrey Kolobov and Furong Huang and Jianwei Yang}, |
|
year={2024}, |
|
eprint={2412.10345}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.RO}, |
|
url={https://arxiv.org/abs/2412.10345}, |
|
} |
|
``` |