gray311 committed b6c918d (verified) · Parent: 01cbad9

Create README.md

Files changed (1): README.md (+189, -0)
## Overview

Fully autonomous vehicles (AVs) must navigate complex real-world scenarios with human-like understanding and responsiveness. We introduce Dolphins, a novel vision-language model architected to exhibit human-like driving abilities. Dolphins processes multimodal inputs comprising video (or image) data, text instructions, and historical control signals, and generates outputs corresponding to the provided instructions. Building on the open-sourced pretrained vision-language model OpenFlamingo, we tailored Dolphins to the driving domain by constructing driving-specific instruction data and performing instruction tuning. Using the BDD-X dataset, we designed and consolidated four distinct AV tasks into Dolphins to foster a holistic understanding of intricate driving scenarios. The distinctive features of Dolphins fall along two dimensions: (1) the ability to provide a comprehensive understanding of complex and long-tailed open-world driving scenarios and to solve a spectrum of AV tasks, and (2) the emergence of human-like capabilities, including gradient-free rapid learning and adaptation via in-context learning, reflection and error recovery, and interoperability.

### Initialization

```python
import torch
from huggingface_hub import hf_hub_download
from peft import LoraConfig

from mllm.src.factory import create_model_and_transforms
from configs.lora_config import openflamingo_tuning_config

# Build the OpenFlamingo-style model with LoRA adapters enabled.
peft_config = LoraConfig(**openflamingo_tuning_config)
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14-336",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",
    tokenizer_path="anas-awadalla/mpt-7b",
    cross_attn_every_n_layers=4,
    use_peft=True,
    peft_config=peft_config,
)

# Grab the fine-tuned checkpoint from the Hugging Face Hub and load the weights.
checkpoint_path = hf_hub_download("gray311/Dolphins", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```
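
After loading, the model is typically run on the GPU in half precision, exactly as the generation script below does with `model.half().cuda()`. The short snippet that follows is an optional sketch (not part of the original instructions): it assumes a CUDA device and adds a sanity check that only the LoRA adapter weights remain trainable.

```python
# Optional follow-up (assumes a CUDA device): run inference in half precision,
# mirroring model.half().cuda() in the generation script below.
model.half().cuda()
model.eval()

# Sanity check: with LoRA adapters, only a small fraction of parameters should be trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")
```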
### Generation example

Below is an example of generating text conditioned on a driving video.

```python
import os
import random
import mimetypes
from typing import Union

import cv2
import numpy as np
import requests
from PIL import Image

import torch
from peft import LoraConfig
from huggingface_hub import hf_hub_download

from configs.lora_config import openflamingo_tuning_config
from mllm.src.factory import create_model_and_transforms


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def get_content_type(file_path):
    content_type, _ = mimetypes.guess_type(file_path)
    return content_type


# ------------------- Image and Video Handling Functions -------------------
def extract_frames(video_path, num_frames=16):
    """Sample `num_frames` evenly spaced RGB frames from a video."""
    video = cv2.VideoCapture(video_path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_step = total_frames // num_frames
    frames = []

    for i in range(num_frames):
        video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
        ret, frame = video.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame).convert("RGB"))

    video.release()
    return frames


def get_image(url: str) -> Union[Image.Image, list]:
    """Load an image, or a list of video frames, from a local path or a remote URL."""
    if "://" not in url:  # Local file
        content_type = get_content_type(url)
    else:  # Remote URL
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")

    if "image" in content_type:
        if "://" not in url:  # Local file
            return Image.open(url)
        else:  # Remote URL
            return Image.open(requests.get(url, stream=True, verify=False).raw)
    elif "video" in content_type:
        video_path = "temp_video.mp4"
        if "://" not in url:  # Local file
            video_path = url
        else:  # Remote URL
            with open(video_path, "wb") as f:
                f.write(requests.get(url, stream=True, verify=False).content)
        frames = extract_frames(video_path)
        if "://" in url:  # Only remove the temporary video file if it was downloaded
            os.remove(video_path)
        return frames
    else:
        raise ValueError("Invalid content type. Expected image or video.")


def load_pretrained_model():
    peft_config = LoraConfig(**openflamingo_tuning_config)
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14-336",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="anas-awadalla/mpt-7b",
        tokenizer_path="anas-awadalla/mpt-7b",
        cross_attn_every_n_layers=4,
        use_peft=True,
        peft_config=peft_config,
    )

    checkpoint_path = hf_hub_download("gray311/Dolphins", "checkpoint.pt")
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    model.half().cuda()

    return model, image_processor, tokenizer


def get_model_inputs(video_path, instruction, model, image_processor, tokenizer):
    frames = get_image(video_path)
    # Shape: (batch, num_media, num_frames, channels, height, width)
    vision_x = torch.stack([image_processor(image) for image in frames], dim=0).unsqueeze(0).unsqueeze(0)
    assert vision_x.shape[2] == len(frames)
    prompt = [
        f"USER: <image> is a driving video. {instruction} GPT:<answer>"
    ]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    return vision_x, inputs


if __name__ == "__main__":
    video_path = "path/to/your/video"
    instruction = "Please describe this video in detail."

    model, image_processor, tokenizer = load_pretrained_model()
    vision_x, inputs = get_model_inputs(video_path, instruction, model, image_processor, tokenizer)
    generation_kwargs = {
        'max_new_tokens': 512, 'temperature': 1,
        'top_k': 0, 'top_p': 1, 'no_repeat_ngram_size': 3, 'length_penalty': 1,
        'do_sample': False,
        'early_stopping': True,
    }

    generated_tokens = model.generate(
        vision_x=vision_x.half().cuda(),
        lang_x=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        num_beams=3,
        **generation_kwargs,
    )

    if isinstance(generated_tokens, tuple):
        generated_tokens = generated_tokens[0]
    generated_tokens = generated_tokens.cpu().numpy()

    generated_text = tokenizer.batch_decode(generated_tokens)

    print(
        f"Dolphins output:\n\n{generated_text}"
    )
```
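
The overview mentions gradient-free adaptation via in-context learning. The sketch below illustrates how a few-shot prompt could be assembled from the helpers above; it is a hypothetical example rather than the released demo code: the example paths and answers are placeholders, the `<|endofchunk|>` separator follows the OpenFlamingo convention, and the exact few-shot format expected by the checkpoint may differ.

```python
# Hypothetical few-shot prompt built from the helpers above (get_image, image_processor,
# tokenizer). Paths and answers are placeholders; <|endofchunk|> follows the OpenFlamingo
# convention and may differ from the format Dolphins was trained with.
context_videos = ["path/to/incontext_example1.mp4", "path/to/incontext_example2.mp4"]
context_answers = [
    "The car moves forward slowly because traffic ahead is heavy.",
    "The car stops because the traffic light turns red.",
]
query_video = "path/to/your/video"
instruction = "Why does the car behave in this way?"

# Stack each video's frames (assumes every video yields the same number of frames),
# then add a batch dimension: (batch=1, num_media, num_frames, channels, height, width).
frame_groups = [get_image(path) for path in context_videos + [query_video]]
vision_x = torch.stack(
    [torch.stack([image_processor(frame) for frame in frames], dim=0) for frames in frame_groups],
    dim=0,
).unsqueeze(0)

# Interleave one <image> placeholder per video with its answer, then append the query.
prompt = "".join(
    f"USER: <image> is a driving video. {instruction} GPT:<answer> {answer}<|endofchunk|>"
    for answer in context_answers
) + f"USER: <image> is a driving video. {instruction} GPT:<answer>"
inputs = tokenizer([prompt], return_tensors="pt")
```

Generation then proceeds exactly as in the script above, passing `vision_x.half().cuda()`, `inputs["input_ids"].cuda()`, and the attention mask to `model.generate`.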