wangclnlp committed on
Commit 92dff2a · verified · 1 Parent(s): 7186f2d

Update README.md

Files changed (1)
  1. README.md +110 -110
README.md CHANGED
---
license: mit
tags:
- vision
- DPO
- RLHF
- preference
- feedback
- reward model
- preference model
---

#### Robust Visual Reward Model
The robust visual reward model (RoVRM) is trained with a three-phase progressive training scheme (pre-training on textual preference data → fine-tuning on image caption-based preference data → fine-tuning on visual preference data) together with optimal transport-based preference data selection.
These techniques transfer preferences from auxiliary textual data to strengthen the model's robustness.
This repository hosts the RoVRM built on the LLaVA-1.5-7B model.
We employed RoVRM for best-of-$n$ sampling and RL training, where it significantly improves performance and reduces hallucination in large vision-language models.
Detailed training settings and experimental results are available in our [paper](https://arxiv.org/abs/2408.12109).

![alt text](figure/main_image.png)

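The optimal transport-based selection step ranks the auxiliary textual preference samples by how closely they match the target visual task before training begins. The snippet below is only a rough sketch of that idea, not the authors' implementation: the embeddings, sample counts, regularization value, and the small Sinkhorn solver are illustrative assumptions.

```python
import torch

def sinkhorn(cost, a, b, reg=0.1, n_iters=200):
    # Entropy-regularized optimal transport: returns a transport plan whose
    # rows and columns approximately match the marginals a and b.
    K = torch.exp(-cost / reg)
    u = torch.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.t() @ u)
        u = a / (K @ v)
    return u.unsqueeze(1) * K * v.unsqueeze(0)

torch.manual_seed(0)
text_emb = torch.randn(1000, 256)  # stand-in embeddings of textual preference samples
vis_emb = torch.randn(200, 256)    # stand-in embeddings of target visual-task samples

# Cosine-distance cost between every textual sample and every visual sample.
text_n = torch.nn.functional.normalize(text_emb, dim=-1)
vis_n = torch.nn.functional.normalize(vis_emb, dim=-1)
cost = 1.0 - text_n @ vis_n.t()

a = torch.full((text_emb.size(0),), 1.0 / text_emb.size(0))  # uniform source marginal
b = torch.full((vis_emb.size(0),), 1.0 / vis_emb.size(0))    # uniform target marginal
plan = sinkhorn(cost, a, b)

# Textual samples that transport the most mass toward the visual task are kept.
selected = plan.sum(dim=1).topk(k=300).indices
print(selected.shape)
```
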
## How to use the model
We recommend using the [Vision-LLM-Alignment](https://github.com/wangclnlp/Vision-LLM-Alignment) system to run RoVRM, since it is the system RoVRM was trained with.

To evaluate a question-answer pair with RoVRM, follow two steps:
1. Convert the safetensors-format model to ```pytorch_model.bin``` using the ```convert_pytorch_bin.py``` script (a rough sketch of this conversion is shown below the list).
2. Download the Vision-LLM-Alignment repository and run the demo from the top-level directory of the repository.

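The repository's ```convert_pytorch_bin.py``` performs step 1; the snippet below is only an illustrative sketch of what such a conversion does, assuming the RoVRM checkpoint is stored as one or more ```*.safetensors``` shards under ```models/``` (use the actual script for real conversions).

```python
# Illustrative safetensors -> pytorch_model.bin conversion (assumed paths).
import glob
import torch
from safetensors.torch import load_file

state_dict = {}
for shard in sorted(glob.glob("models/*.safetensors")):  # assumed checkpoint location
    state_dict.update(load_file(shard))                  # merge all shards into one state dict

torch.save(state_dict, "models/pytorch_model.bin")       # where the demo below expects it
```

With the converted checkpoint in place, the demo loads RoVRM and scores a question-answer pair:
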
```python
from transformers import AutoProcessor
from training.utils.model.modeling_reward import create_reward_or_critic_model
from torch.utils.data.dataloader import default_collate
from PIL import Image
import copy
import torch
import argparse

device = torch.device("cuda:0")

# Set the path of the base vision-LLM and the path of the RoVRM checkpoint.
# The llava-1.5-7b model is needed to build an initialized RoVRM.
base_path = "base_models/llava-1.5-7b-hf"

# The converted RoVRM checkpoint.
ckpt_path = "models/pytorch_model.bin"

processor = AutoProcessor.from_pretrained(base_path)
image_processor = processor.image_processor
tokenizer = processor.tokenizer

tokenizer.add_bos_token = True
tokenizer.add_eos_token = True

args = argparse.Namespace(
    model_architecture="llava",
    lang_decoder_update=False,
    from_checkpoint=base_path,
)

# Build the reward model on top of the base vision-LLM.
model, image_processor, tokenizer = create_reward_or_critic_model(
    text_tokenizer=tokenizer,
    args=args)

# Load the RoVRM weights.
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False)
model.to(device)

# Set the input sentence and the path of the input image.
# The <image> placeholder is required whenever an image input is given.
input_sen = "USER: ### Image:<image>\nIdentify and describe each object in the image in detail.\nASSISTANT: In the image, there is a cute, colorful cartoon girl sitting on a chair at a wooden table. She is reading a book, which is a prominent object in the scene. The table and chair are also present, adding to the overall setting. As this is a cartoon-style image, the girl and the book may have a more exaggerated or simplified design compared to real-life objects. "
img_path = "llava1.5_raw_images_00011_000118793.jpg"

# Load and preprocess the image.
image = Image.open(img_path).convert("RGB")
image = image_processor(image)
try:
    # Most processors return a dict-like object holding "pixel_values".
    image = image['pixel_values'][0]
except (TypeError, KeyError, IndexError):
    pass

# Tokenize the question-answer pair.
input_sen = tokenizer(input_sen,
                      return_tensors=None,
                      padding="do_not_pad",
                      truncation=True,
                      max_length=512)

input_sen.update(labels=copy.deepcopy(input_sen["input_ids"]))
input_sen.update(image=image)

# Score the question-answer pair with RoVRM; a higher reward means a better response.
reward_scores = model(img=default_collate(image).reshape((-1,) + image[0].shape[-3:]).unsqueeze(0).to(device),
                      lang=torch.LongTensor(input_sen["input_ids"]).unsqueeze(0).to(device),
                      attention_mask=torch.LongTensor(input_sen["attention_mask"]).unsqueeze(0).to(device),
                      input_labels=torch.LongTensor(input_sen["labels"]).unsqueeze(0).to(device))

print(reward_scores[0].item())
```
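
The same score can drive best-of-$n$ sampling: generate $n$ candidate responses with the policy model, score each one with RoVRM, and keep the highest-reward candidate. A minimal sketch, assuming a hypothetical ```score(question, answer, image)``` helper that wraps the scoring code above:

```python
# Hypothetical best-of-n selection; `score` wraps the RoVRM scoring code above and
# `candidates` would come from sampling the policy model n times.
def best_of_n(question, candidates, image, score):
    scored = [(score(question, answer, image), answer) for answer in candidates]
    return max(scored, key=lambda pair: pair[0])[1]  # answer with the highest reward
```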

Please cite our paper if you find RoVRM helpful in your work🌹🌹🌹:
```bibtex
@inproceedings{wang2025rovrm,
  title={RoVRM: A Robust Visual Reward Model Optimized via Auxiliary Textual Preference Data},
  author={Wang, Chenglong and Gan, Yang and Huo, Yifu and Mu, Yongyu and Yang, Murun and He, Qiaozhi and Xiao, Tong and Zhang, Chunliang and Liu, Tongran and Zhu, Jingbo},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={39},
  number={24},
  pages={25336--25344},
  year={2025}
}
```