CodeGoat24 commited on
Commit
77c5232
·
verified ·
1 Parent(s): 36e3e1c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +91 -3
README.md CHANGED
@@ -1,3 +1,91 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ base_model:
4
+ - lmms-lab/LLaVA-Video-7B-Qwen2
5
+ ---
6
+
7
+ # LLaVA-Video-7B-Qwen2-UnifiedReward-DPO
8
+
9
+ ## Model Summary
10
+
11
+ This model is trained on LLaVA-Video-7B-Qwen2 based on DPO preference data constructed by our [UnifiedReward-7B](https://huggingface.co/CodeGoat24/UnifiedReward-7b) for enhanced video understanding ability.
12
+
13
+ For further details, please refer to the following resources:
14
+ - 📰 Paper:
15
+ - 🪐 Project Page: https://codegoat24.github.io/UnifiedReward/
16
+ - 🤗 Model Collections: https://huggingface.co/collections/CodeGoat24/unifiedreward-models-67c3008148c3a380d15ac63a
17
+ - 🤗 Dataset Collections: https://huggingface.co/collections/CodeGoat24/unifiedreward-training-data-67c300d4fd5eff00fa7f1ede
18
+ - 👋 Point of Contact: [Yibin Wang](https://codegoat24.github.io)
19
+
20
+
21
+ ### Quick Start
22
+
23
+ ~~~python
24
+ # pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
25
+ from llava.model.builder import load_pretrained_model
26
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
27
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
28
+ from llava.conversation import conv_templates, SeparatorStyle
29
+ from PIL import Image
30
+ import requests
31
+ import copy
32
+ import torch
33
+ import sys
34
+ import warnings
35
+ from decord import VideoReader, cpu
36
+ import numpy as np
37
+ warnings.filterwarnings("ignore")
38
+ def load_video(video_path, max_frames_num,fps=1,force_sample=False):
39
+ if max_frames_num == 0:
40
+ return np.zeros((1, 336, 336, 3))
41
+ vr = VideoReader(video_path, ctx=cpu(0),num_threads=1)
42
+ total_frame_num = len(vr)
43
+ video_time = total_frame_num / vr.get_avg_fps()
44
+ fps = round(vr.get_avg_fps()/fps)
45
+ frame_idx = [i for i in range(0, len(vr), fps)]
46
+ frame_time = [i/fps for i in frame_idx]
47
+ if len(frame_idx) > max_frames_num or force_sample:
48
+ sample_fps = max_frames_num
49
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
50
+ frame_idx = uniform_sampled_frames.tolist()
51
+ frame_time = [i/vr.get_avg_fps() for i in frame_idx]
52
+ frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
53
+ spare_frames = vr.get_batch(frame_idx).asnumpy()
54
+ # import pdb;pdb.set_trace()
55
+ return spare_frames,frame_time,video_time
56
+ pretrained = "CodeGoat24/LLaVA-Video-7B-Qwen2-UnifiedReward-DPO"
57
+ model_name = "llava_qwen"
58
+ device = "cuda"
59
+ device_map = "auto"
60
+ tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map) # Add any other thing you want to pass in llava_model_args
61
+ model.eval()
62
+ video_path = "XXXX"
63
+ max_frames_num = 64
64
+ video,frame_time,video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
65
+ video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda().half()
66
+ video = [video]
67
+ conv_template = "qwen_1_5" # Make sure you use correct chat template for different models
68
+ question = DEFAULT_IMAGE_TOKEN + "\nPlease describe this video in detail."
69
+ conv = copy.deepcopy(conv_templates[conv_template])
70
+ conv.append_message(conv.roles[0], question)
71
+ conv.append_message(conv.roles[1], None)
72
+ prompt_question = conv.get_prompt()
73
+ input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
74
+ cont = model.generate(
75
+ input_ids,
76
+ images=video,
77
+ modalities= ["video"],
78
+ do_sample=False,
79
+ temperature=0,
80
+ max_new_tokens=4096,
81
+ )
82
+ text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
83
+ print(text_outputs)
84
+ ~~~
85
+
86
+
87
+ ## Citation
88
+
89
+ ```
90
+
91
+ ```