|
--- |
|
license: mit |
|
tags: |
|
- vision |
|
- DPO |
|
- RLHF |
|
- preference |
|
- feedback |
|
- reward model |
|
- preference model |
|
--- |
|
|
|
# Robust Visual Reward Model
|
The Robust Visual Reward Model (RoVRM) is developed through three-phase progressive training (i.e., pre-training on textual preference data → fine-tuning on image caption-based preference data → fine-tuning on visual preference data) combined with optimal transport-based preference data selection.

These techniques effectively transfer preferences from auxiliary textual data to improve the model's robustness.
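As a rough illustration of what optimal transport-based selection can look like (a sketch under simplifying assumptions, not the paper's exact procedure), each auxiliary textual sample can be scored by its transport cost toward the visual preference set, keeping the cheapest samples. The sketch below assumes precomputed embeddings for both datasets and uses the POT (`ot`) library:

```python
import numpy as np
import ot  # POT: Python Optimal Transport

def select_by_transport(text_emb, vis_emb, k, reg=0.05):
    """Pick the k textual preference samples closest (in OT terms) to the
    visual preference distribution. text_emb: (n, d) and vis_emb: (m, d)
    are assumed to be precomputed embeddings; this is an illustrative
    sketch, not the paper's exact selection procedure."""
    n, m = len(text_emb), len(vis_emb)
    a = np.full(n, 1.0 / n)           # uniform weights on textual samples
    b = np.full(m, 1.0 / m)           # uniform weights on visual samples
    M = ot.dist(text_emb, vis_emb)    # squared-Euclidean cost matrix
    plan = ot.sinkhorn(a, b, M, reg)  # entropic-regularized transport plan
    cost = (plan * M).sum(axis=1)     # per-textual-sample transport cost
    return np.argsort(cost)[:k]       # indices of the k best-aligned samples
```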
|
This repository hosts the RoVRM built on the LLaVA-1.5-7B model.

We applied RoVRM to best-of-$n$ sampling and RL training, showing that it significantly improves performance and reduces hallucination in large vision-language models.
|
Detailed training information and experimental results are available in our [paper](https://arxiv.org/abs/2408.12109). |
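For intuition, best-of-$n$ sampling simply generates $n$ candidate answers and keeps the one the reward model scores highest. Below is a minimal, model-agnostic sketch; the `generate` and `score` callables are hypothetical stand-ins for an LVLM sampler and for RoVRM scoring (see the demo further down), not part of this repository:

```python
# Minimal best-of-n sketch. `generate` and `score` are caller-supplied
# callables (hypothetical stand-ins for an LVLM sampler and for RoVRM
# scoring, respectively); neither is part of this repository.
def best_of_n(generate, score, question, image, n=8):
    candidates = [generate(question, image) for _ in range(n)]
    return max(candidates, key=lambda answer: score(question, image, answer))
```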
|
|
|
|
|
|
## How to use the model |
|
We recommend running RoVRM with the [Vision-LLM-Alignment](https://github.com/wangclnlp/Vision-LLM-Alignment) system, which was also used to train it.
|
|
|
To score a question-answer pair with RoVRM, follow these two steps:

1. Convert the safetensors-format model to ```pytorch_model.bin``` using the ```convert_pytorch_bin.py``` script (a sketch of this conversion follows the list).

2. Download the Vision-LLM-Alignment repository and run the demo below from its top-level directory.
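For step 1, the conversion typically amounts to reading the safetensors weights and re-saving them with `torch.save`. The following is a minimal sketch under the assumption of a single-file checkpoint named `model.safetensors`; the actual `convert_pytorch_bin.py` script may handle sharded checkpoints and key remapping differently:

```python
# Hedged sketch of the safetensors -> pytorch_model.bin conversion.
# Assumes a single-file checkpoint named model.safetensors; the real
# convert_pytorch_bin.py may handle sharding and key remapping differently.
import torch
from safetensors.torch import load_file

state_dict = load_file("models/model.safetensors")
torch.save(state_dict, "models/pytorch_model.bin")
```

The step 2 demo then scores a single question-answer pair: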
|
|
|
```python |
|
import argparse

import copy

import torch

from torch.utils.data.dataloader import default_collate

from PIL import Image

from transformers import AutoProcessor

from training.utils.model.modeling_reward import create_reward_or_critic_model

device = torch.device("cuda:0")
|
|
|
# Path to the base model: llava-1.5-7b is needed to build an
# initialized RoVRM before the trained weights are loaded.
base_path = "base_models/llava-1.5-7b-hf"

# Path to the converted RoVRM checkpoint (output of convert_pytorch_bin.py).
ckpt_path = "models/pytorch_model.bin"
|
|
|
processor = AutoProcessor.from_pretrained(base_path) |
|
image_processor = processor.image_processor |
|
tokenizer = processor.tokenizer |
|
|
|
# Prepend BOS and append EOS tokens when tokenizing.
tokenizer.add_bos_token = True
tokenizer.add_eos_token = True
|
|
|
args = {
    "model_architecture": "llava",
    "lang_decoder_update": False,
    "from_checkpoint": base_path
}
args = argparse.Namespace(**args)

model, image_processor, tokenizer = create_reward_or_critic_model(
    text_tokenizer=tokenizer,
    args=args)
|
|
|
# ckpt_path already points to the .bin file, so load it directly.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
model.to(device)
model.eval()
|
|
|
# Set the input sentence and the path of the input image.
# The <image> placeholder is required whenever an image is provided.
|
input_sen = "USER: ### Image:<image>\nIdentify and describe each object in the image in detail.\nASSISTANT: In the image, there is a cute, colorful cartoon girl sitting on a chair at a wooden table. She is reading a book, which is a prominent object in the scene. The table and chair are also present, adding to the overall setting. As this is a cartoon-style image, the girl and the book may have a more exaggerated or simplified design compared to real-life objects. " |
|
img_path = "llava1.5_raw_images_00011_000118793.jpg" |
|
|
|
# Load and preprocess the image.
image = Image.open(img_path).convert("RGB")
image = image_processor(image)
try:
    # HF image processors return a BatchFeature; unwrap the raw pixel values.
    image = image["pixel_values"][0]
except (KeyError, TypeError):
    pass
|
|
|
input_sen = tokenizer(input_sen,
                      return_tensors=None,
                      padding="do_not_pad",
                      truncation=True,
                      max_length=512)
|
|
|
# The reward model's forward pass expects labels; mirror the input_ids.
input_sen.update(labels=copy.deepcopy(input_sen["input_ids"]))
input_sen.update(image=image)
|
|
|
# Collate the image into a (1, 3, H, W) batch and score the pair.
img_tensor = default_collate(image).reshape((-1,) + image[0].shape[-3:]).unsqueeze(0).to(device)
with torch.no_grad():
    reward_scores = model(img=img_tensor,
                          lang=torch.LongTensor(input_sen["input_ids"]).unsqueeze(0).to(device),
                          attention_mask=torch.LongTensor(input_sen["attention_mask"]).unsqueeze(0).to(device),
                          input_labels=torch.LongTensor(input_sen["labels"]).unsqueeze(0).to(device))
|
|
|
# Scalar reward for the question-answer pair.
print(reward_scores[0].item())
|
``` |
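The printed value is RoVRM's scalar reward for the question-answer pair; when ranking multiple candidate responses (e.g., for best-of-$n$ sampling), higher scores indicate the responses the model prefers.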
|
|
|
Please cite our paper if you find RoVRM helpful in your work🌹🌹🌹: |
|
``` |
|
@inproceedings{wang2025rovrm,
  title={RoVRM: A Robust Visual Reward Model Optimized via Auxiliary Textual Preference Data},
  author={Wang, Chenglong and Gan, Yang and Huo, Yifu and Mu, Yongyu and Yang, Murun and He, Qiaozhi and Xiao, Tong and Zhang, Chunliang and Liu, Tongran and Zhu, Jingbo},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={39},
  number={24},
  pages={25336--25344},
  year={2025}
}
|
``` |