X-iZhang committed on
Commit 23c9ef8 · verified · 1 Parent(s): c86e5df

Upload 27 files

libra/__init__.py ADDED
@@ -0,0 +1 @@
from .model import LibraLlamaForCausalLM
libra/constants.py ADDED
@@ -0,0 +1,12 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
libra/conversation.py ADDED
@@ -0,0 +1,416 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple

# The following imports are required by process_image() but were missing from the uploaded file.
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role

        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""

            ret = ret.lstrip(self.sep)

        elif self.sep_style == SeparatorStyle.LLAMA_3:
            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}{self.sep2}" if len(msg) > 0 else ""
            wrap_role = lambda role, msg: f"<|start_header_id|>{role}<|end_header_id|>\n\n{msg}{self.sep2}\n"

            ret = ""  # "<|begin_of_text|>"

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"

                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message

                    if i == 0:
                        ret += wrap_sys(self.system)
                        ret += wrap_role("user", message)
                    else:
                        role_name = "user" if role == self.roles[0] else "assistant"
                        ret += wrap_role(role_name, message)
                else:
                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"

            ret = ret.strip()

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(0, 0, 0)):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((518, 518))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llama_3 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep=" ",
    sep2="<|eot_id|>",
)

conv_libra_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_libra_llama_3 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep=" ",
    sep2="<|eot_id|>",
)

conv_libra_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_libra_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_libra_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_libra_v1 = Conversation(
    system="The assistant specialized in comparing Chest X-ray images, identifying differences, and noting temporal changes.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


conv_libra_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_libra_v1,

    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,

    "plain": conv_libra_plain,
    "libra_v0": conv_libra_v0,
    "libra_v1": conv_libra_v1,

    "libra_v0_mmtag": conv_libra_v0_mmtag,
    "libra_v1_mmtag": conv_libra_v1_mmtag,

    "llama_2": conv_llama_2,
    "libra_llama_2": conv_libra_llama_2,

    "llama_3": conv_llama_3,
    "libra_llama_3": conv_libra_llama_3,

    "mpt": conv_mpt,
}

if __name__ == "__main__":
    print(default_conversation.get_prompt())
libra/eval/__init__.py ADDED
@@ -0,0 +1,6 @@
try:
    from .run_libra import libra_eval
    from .temporal_f1 import temporal_f1_score
    from .radiology_report import evaluate_report
except:
    pass
libra/eval/eval_vqa_libra.py ADDED
@@ -0,0 +1,264 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import numpy as np
import re
import requests  # required by load_images() for URL inputs; missing from the uploaded file

from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from libra.conversation import conv_templates, SeparatorStyle
from libra.model.builder import load_pretrained_model
from libra.utils import disable_torch_init
from libra.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria

import math
import pydicom
from PIL import Image
from io import BytesIO
from pydicom.pixel_data_handlers.util import apply_voi_lut

def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]

def load_images(image_file):
    """
    Load an image from a local file, a URL, or a DICOM file.

    Args:
        image_file (str): The path or URL of the image file to load.

    Returns:
        PIL.Image.Image: The loaded image in RGB format.

    Raises:
        ValueError: If the DICOM file does not contain image data.
        TypeError: If the input is neither a valid file path nor a URL.
    """
    if isinstance(image_file, str):
        # Case 1: Load from URL
        if image_file.startswith(('http://', 'https://')):
            try:
                response = requests.get(image_file)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {image_file}\n{e}")

        # Case 2: Load from DICOM file
        elif image_file.lower().endswith('.dcm'):
            try:
                dicom = pydicom.dcmread(image_file)
                if 'PixelData' in dicom:
                    data = apply_voi_lut(dicom.pixel_array, dicom)

                    # Handle MONOCHROME1 images
                    if dicom.PhotometricInterpretation == "MONOCHROME1":
                        data = np.max(data) - data

                    # Normalize the image data
                    data = data - np.min(data)
                    data = data / np.max(data)
                    data = (data * 255).astype(np.uint8)

                    # Convert to 3-channel RGB if necessary
                    if data.ndim == 2:
                        data = np.stack([data] * 3, axis=-1)

                    image = Image.fromarray(data).convert('RGB')
                else:
                    raise ValueError("DICOM file does not contain image data")
            except Exception as e:
                raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")

        # Case 3: Load standard image files (e.g., PNG, JPG)
        else:
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading standard image file: {image_file}\n{e}")

    else:
        raise TypeError("image_file must be a string representing a file path or URL")

    return image

def get_image_tensors(image_file, image_folder, image_processor, model, device='cuda'):
    # Load and preprocess the images
    if isinstance(image_file, str):
        image = []
        image_path = os.path.join(image_folder, image_file)
        img = load_images(image_path)
        image.append(img)
    elif isinstance(image_file, (list, tuple)):
        image = []
        image_paths = [os.path.join(image_folder, file_name) for file_name in image_file]
        for path in image_paths:
            img = load_images(path)
            image.append(img)
    else:
        raise TypeError("image_file must be a string or a str/list/tuple of strings")

    # Ensure two images are present
    if len(image) != 2:
        image.append(image[0])
        if model.config.mm_projector_type == "TAC":
            print("Contains only current image. Adding a dummy prior image for TAC.")

    # Process each image
    processed_images = []
    for img_data in image:
        image_temp = process_images([img_data], image_processor, model.config)[0]
        image_temp = image_temp.to(device=device, non_blocking=True)
        processed_images.append(image_temp)

    # Separate current and prior images
    cur_images = [processed_images[0]]
    prior_images = [processed_images[1]]

    # Stack and return as batched tensor
    batch_images = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])

    return batch_images

def eval_model(args):
    """
    Evaluate a pre-trained model on a set of questions and images.
    Args:
        args (Namespace): A namespace object containing the following attributes:
            - model_path (str): Path to the pre-trained model.
            - model_base (str): Base model name.
            - question_file (str): Path to the JSON file containing questions.
            - num_chunks (int): Number of chunks to split the questions into.
            - chunk_idx (int): Index of the chunk to process.
            - answers_file (str): Path to the file where answers will be saved.
            - image_folder (str): Folder containing the images.
            - conv_mode (str): Conversation mode to use.
            - temperature (float): Sampling temperature for generation.
            - top_p (float): Top-p sampling parameter.
            - num_beams (int): Number of beams for beam search.
            - max_new_tokens (int): Maximum number of new tokens to generate.
            - length_penalty (float): Length penalty for beam search.
            - num_return_sequences (int): Number of sequences to return.
    Raises:
        TypeError: If `image_file` is neither a string nor a list/tuple of strings.
    Returns:
        None
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")

    for line in tqdm(questions):
        idx = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
        pad_token_id = tokenizer.pad_token_id

        image_tensors = get_image_tensors(image_file, args.image_folder, image_processor, model)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            torch.cuda.empty_cache()
            if args.num_beams > 1:
                output_ids = model.generate(
                    input_ids=input_ids,
                    images=image_tensors,
                    do_sample=False,
                    num_beams=args.num_beams,
                    no_repeat_ngram_size=3,
                    max_new_tokens=args.max_new_tokens,
                    stopping_criteria=[stopping_criteria],
                    use_cache=True,
                    length_penalty=args.length_penalty,
                    output_scores=True,
                    num_return_sequences=args.num_return_sequences,
                    attention_mask=attention_mask,
                    pad_token_id=pad_token_id)
            else:
                output_ids = model.generate(
                    input_ids,
                    images=image_tensors,
                    do_sample=True,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    no_repeat_ngram_size=3,
                    max_new_tokens=args.max_new_tokens,
                    stopping_criteria=[stopping_criteria],
                    use_cache=True,
                    attention_mask=attention_mask,
                    pad_token_id=pad_token_id)

        torch.cuda.empty_cache()
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()

        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
                                   "prompt": cur_prompt,
                                   "text": outputs,
                                   "answer_id": ans_id,
                                   "model_id": model_name,
                                   "metadata": {}}) + "\n")
        ans_file.flush()
    ans_file.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="libra")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="libra_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--num_return_sequences", type=int, default=None)
    parser.add_argument("--length_penalty", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()

    eval_model(args)
libra/eval/radiology_report.py ADDED
@@ -0,0 +1,222 @@
import argparse
import json
import os
import re
import sys

import evaluate
import numpy as np
import pandas as pd
from tqdm import tqdm

from libra.eval import temporal_f1_score

# Pre-load metrics
bertscore_metric = evaluate.load("bertscore")
rouge_metric = evaluate.load('rouge')
bleu_metric = evaluate.load("bleu")
meteor_metric = evaluate.load('meteor')


def clean_text(text: str) -> str:
    """
    Perform basic cleanup of text by removing newlines, dashes, and some special patterns.
    """
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'[_-]+', ' ', text)
    text = re.sub(r'\(___, __, __\)', '', text)
    text = re.sub(r'---, ---, ---', '', text)
    text = re.sub(r'\(__, __, ___\)', '', text)
    text = re.sub(r'[_-]+', ' ', text)
    text = re.sub(r'[^\w\s.,:;()\-]', '', text)
    text = re.sub(r'\s{2,}', ' ', text).strip()
    return text


def load_json(path: str) -> list:
    """
    Load a JSONL file and return a list of parsed objects.
    Each line should be a valid JSON object.
    """
    content = []
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            content.append(json.loads(line))
    return content


def extract_sections(data: list) -> list:
    """
    Extract relevant text sections (e.g., findings, impression, text)
    from a list of JSON objects and clean each item.
    """
    sections_list = []
    for item in data:
        if 'reference' in item:
            cleaned_text = clean_text(item['reference'].lower())
            sections_list.append(cleaned_text)
        elif 'findings' in item:
            cleaned_text = clean_text(item['findings'].lower())
            sections_list.append(cleaned_text)
        elif 'impression' in item:
            cleaned_text = clean_text(item['impression'].lower())
            sections_list.append(cleaned_text)
        elif 'text' in item:
            cleaned_text = clean_text(item['text'].lower())
            sections_list.append(cleaned_text)
    return sections_list


def append_results_to_csv(results: dict, model_name: str, csv_path: str) -> None:
    """
    Convert the results dictionary into a DataFrame and append it to a CSV file.
    Inserts 'Model Name' at the first column if it doesn't exist.
    Creates a new CSV if it doesn't exist, otherwise appends.
    """
    df = pd.DataFrame([results])
    df.insert(0, "Model Name", model_name)

    header = not os.path.isfile(csv_path)  # If file doesn't exist, write the header
    df.to_csv(csv_path, mode='a', header=header, index=False)


def evaluate_report(
    references: str,
    predictions: str,
) -> dict:
    """
    Evaluate the model outputs against reference texts using multiple metrics:
    - BLEU (1–4)
    - METEOR
    - ROUGE-L
    - BERTScore (F1)
    - Temporal F1

    Returns a dictionary of computed metrics.
    """
    # Load data
    references_data = load_json(references)
    predictions_data = load_json(predictions)

    # Basic validation: question_id alignment
    gt_ids = [item['question_id'] for item in references_data]
    pred_ids = [item['question_id'] for item in predictions_data]
    assert gt_ids == pred_ids, "Please make sure predictions and references are perfectly matched by question_id."

    # Extract text sections
    references_list = extract_sections(references_data)
    predictions_list = extract_sections(predictions_data)

    # Calculate metrics
    with tqdm(total=8, desc="Calculating metrics") as pbar:
        # BLEU-1
        bleu1 = bleu_metric.compute(
            predictions=predictions_list,
            references=references_list,
            max_order=1
        )['bleu']
        print(f"BLEU-1 Score: {round(bleu1 * 100, 2)}")
        pbar.update(1)

        # BLEU-2
        bleu2 = bleu_metric.compute(
            predictions=predictions_list,
            references=references_list,
            max_order=2
        )['bleu']
        print(f"BLEU-2 Score: {round(bleu2 * 100, 2)}")
        pbar.update(1)

        # BLEU-3
        bleu3 = bleu_metric.compute(
            predictions=predictions_list,
            references=references_list,
            max_order=3
        )['bleu']
        print(f"BLEU-3 Score: {round(bleu3 * 100, 2)}")
        pbar.update(1)

        # BLEU-4
        bleu4 = bleu_metric.compute(
            predictions=predictions_list,
            references=references_list,
            max_order=4
        )['bleu']
        print(f"BLEU-4 Score: {round(bleu4 * 100, 2)}")
        pbar.update(1)

        # ROUGE-L
        rougel = rouge_metric.compute(
            predictions=predictions_list,
            references=references_list
        )['rougeL']
        print(f"ROUGE-L Score: {round(rougel * 100, 2)}")
        pbar.update(1)

        # METEOR
        meteor = meteor_metric.compute(
            predictions=predictions_list,
            references=references_list
        )['meteor']
        print(f"METEOR Score: {round(meteor * 100, 2)}")
        pbar.update(1)

        # BERTScore (mean F1)
        bert_f1 = bertscore_metric.compute(
            predictions=predictions_list,
            references=references_list,
            model_type='distilbert-base-uncased'
        )['f1']
        bert_score = float(np.mean(bert_f1))
        print(f"Bert Score: {round(bert_score * 100, 2)}")
        pbar.update(1)

        # Temporal F1
        tem_f1 = temporal_f1_score(
            predictions=predictions_list,
            references=references_list
        )["f1"]
        print(f"Temporal F1 Score: {round(tem_f1 * 100, 2)}")
        pbar.update(1)

    return {
        'BLEU1': round(bleu1 * 100, 2),
        'BLEU2': round(bleu2 * 100, 2),
        'BLEU3': round(bleu3 * 100, 2),
        'BLEU4': round(bleu4 * 100, 2),
        'METEOR': round(meteor * 100, 2),
        'ROUGE-L': round(rougel * 100, 2),
        'Bert_score': round(bert_score * 100, 2),
        'Temporal_entity_score': round(tem_f1 * 100, 2)
    }


def main():
    """
    Parse arguments, compute evaluation metrics, and append the results to a CSV file.
    """
    parser = argparse.ArgumentParser(
        description='Evaluation for Libra Generated Outputs'
    )
    parser.add_argument('--references', type=str, required=True,
                        help='Path to the ground truth file (JSONL).')
    parser.add_argument('--predictions', type=str, required=True,
                        help='Path to the prediction file (JSONL).')
    parser.add_argument('--model-name', type=str, required=True,
                        help='Unique model identifier for tracking in the results CSV.')
    parser.add_argument('--save-to-csv', type=str, required=True,
                        help='Path of the CSV file where results will be saved/appended.')
    args = parser.parse_args()

    # Calculate metrics
    scores_results = evaluate_report(
        references=args.references,
        predictions=args.predictions
    )

    # Append results to CSV
    append_results_to_csv(scores_results, args.model_name, args.save_to_csv)


if __name__ == "__main__":
    main()
libra/eval/run_libra.py ADDED
@@ -0,0 +1,226 @@
import argparse
import torch
import numpy as np  # required by the DICOM branch of load_images(); missing from the uploaded file

from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from libra.conversation import conv_templates, SeparatorStyle
from libra.model.builder import load_pretrained_model
from libra.utils import disable_torch_init
from libra.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria

import requests
import pydicom
from PIL import Image
from io import BytesIO
from pydicom.pixel_data_handlers.util import apply_voi_lut
import datetime


def load_images(image_file):
    """
    Load an image from a local file, a URL, or a DICOM file.

    Args:
        image_file (str): The path or URL of the image file to load.

    Returns:
        PIL.Image.Image: The loaded image in RGB format.

    Raises:
        ValueError: If the DICOM file does not contain image data.
        TypeError: If the input is neither a valid file path nor a URL.
    """
    if isinstance(image_file, str):
        # Case 1: Load from URL
        if image_file.startswith(('http://', 'https://')):
            try:
                response = requests.get(image_file)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {image_file}\n{e}")

        # Case 2: Load from DICOM file
        elif image_file.lower().endswith('.dcm'):
            try:
                dicom = pydicom.dcmread(image_file)
                if 'PixelData' in dicom:
                    data = apply_voi_lut(dicom.pixel_array, dicom)

                    # Handle MONOCHROME1 images
                    if dicom.PhotometricInterpretation == "MONOCHROME1":
                        data = np.max(data) - data

                    # Normalize the image data
                    data = data - np.min(data)
                    data = data / np.max(data)
                    data = (data * 255).astype(np.uint8)

                    # Convert to 3-channel RGB if necessary
                    if data.ndim == 2:
                        data = np.stack([data] * 3, axis=-1)

                    image = Image.fromarray(data).convert('RGB')
                else:
                    raise ValueError("DICOM file does not contain image data")
            except Exception as e:
                raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")

        # Case 3: Load standard image files (e.g., PNG, JPG)
        else:
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading standard image file: {image_file}\n{e}")

    else:
        raise TypeError("image_file must be a string representing a file path or URL")

    return image

def get_image_tensors(image_path, image_processor, model, device='cuda'):
    # Load and preprocess the images
    if isinstance(image_path, str):
        image = []
        img = load_images(image_path)
        image.append(img)
    elif isinstance(image_path, (list, tuple)):
        image = []
        for path in image_path:
            img = load_images(path)
            image.append(img)
    else:
        raise TypeError("image_file must be a string or a str/list/tuple of strings")

    # Ensure two images are present
    if len(image) != 2:
        image.append(image[0])
        if model.config.mm_projector_type == "TAC":
            print("Contains only current image. Adding a dummy prior image for TAC.")

    # Process each image
    processed_images = []
    for img_data in image:
        image_temp = process_images([img_data], image_processor, model.config)[0]
        image_temp = image_temp.to(device=device, non_blocking=True)
        processed_images.append(image_temp)

    # Separate current and prior images
    cur_images = [processed_images[0]]
    prior_images = [processed_images[1]]

    # Stack and return as batched tensor
    batch_images = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])

    return batch_images

def libra_eval(
    model_path=None,
    model_base=None,
    image_file=None,
    query=None,
    conv_mode="libra_v1",
    temperature=0.2,
    top_p=None,
    num_beams=1,
    num_return_sequences=None,
    length_penalty=1.0,
    max_new_tokens=128
):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)

    qs = query
    if model.config.mm_use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    if 'libra' in model_name.lower():
        mode_conv = "libra_v1"

    if conv_mode is not None and mode_conv != conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(mode_conv, conv_mode, conv_mode))
    else:
        conv_mode = mode_conv

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
    pad_token_id = tokenizer.pad_token_id

    image_tensor = get_image_tensors(image_file, image_processor, model)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        torch.cuda.empty_cache()
        if num_beams > 1:
            output_ids = model.generate(
                input_ids=input_ids,
                images=image_tensor,
                do_sample=False,
                num_beams=num_beams,
                no_repeat_ngram_size=3,
                max_new_tokens=max_new_tokens,
                stopping_criteria=[stopping_criteria],
                use_cache=True,
                length_penalty=length_penalty,
                output_scores=True,
                attention_mask=attention_mask,
                pad_token_id=pad_token_id,
                num_return_sequences=num_return_sequences)
        else:
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                no_repeat_ngram_size=3,
                max_new_tokens=max_new_tokens,
                attention_mask=attention_mask,
                pad_token_id=pad_token_id,
                stopping_criteria=[stopping_criteria],
                use_cache=True)

    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()

    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()

    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

    return outputs

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="X-iZhang/libra-v1.0-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default="libra_v1")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--num_return_sequences", type=int, default=None)
    parser.add_argument("--length_penalty", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()

    libra_eval(**vars(args))
libra/eval/temporal_f1.py ADDED
@@ -0,0 +1,153 @@
import re
import argparse
from typing import List, Union

# Keywords used for entity extraction
KEYWORDS = {
    "bigger", "change", "cleared", "constant", "decrease", "decreased", "decreasing", "elevated", "elevation",
    "enlarged", "enlargement", "enlarging", "expanded", "greater", "growing", "improved", "improvement",
    "improving", "increase", "increased", "increasing", "larger", "new", "persistence", "persistent",
    "persisting", "progression", "progressive", "reduced", "removal", "resolution", "resolved", "resolving",
    "smaller", "stability", "stable", "stably", "unchanged", "unfolded", "worse", "worsen", "worsened",
    "worsening", "unaltered"
}

def clean_text(text: str) -> str:
    """
    Clean the input text by removing special characters and redundant spaces or newlines.

    Args:
        text (str): Input text.

    Returns:
        str: Cleaned text.
    """
    # Remove special characters and redundant newlines
    text = re.sub(r'\n+', ' ', text)             # Replace multiple newlines with a single space
    text = re.sub(r'[_-]+', ' ', text)           # Replace underscores and dashes with spaces
    text = re.sub(r'\(___, __, __\)', '', text)  # Remove irrelevant underscore patterns
    text = re.sub(r'---, ---, ---', '', text)    # Remove dashed patterns
    text = re.sub(r'\(__, __, ___\)', '', text)  # Remove similar underscore patterns
    text = re.sub(r'[_-]+', ' ', text)           # Replace underscores and dashes again (if any remain)
    text = re.sub(r'[^\w\s.,:;()-]', '', text)   # Remove non-alphanumeric characters except common punctuation

    # Remove extra spaces
    text = re.sub(r'\s{2,}', ' ', text).strip()
    return text

def extract_entities(text: str, keywords: set) -> set:
    """
    Extract entities from the given text based on the provided keywords.

    Args:
        text (str): Input text.
        keywords (set): Set of keywords to extract entities.

    Returns:
        set: Set of matched keywords found in the text.
    """
    # Clean the text before extracting entities
    text = clean_text(text)

    # Create a regex pattern that matches any of the keywords as whole words
    pattern = r'\b(' + '|'.join(re.escape(word) for word in keywords) + r')\b'

    # Find all matches and return them as a set
    return {match.group().lower() for match in re.finditer(pattern, text.lower())}

def calculate_tem_score(prediction_text: str, reference_text: Union[str, List[str]], epsilon: float = 1e-10) -> dict:
    """
    Calculate the Temporal Entity Matching (TEM) score (similar to F1-score).

    Args:
        reference_text (Union[str, List[str]]): Reference text or a list of reference texts.
        prediction_text (str): Prediction text.
        epsilon (float): Small value to avoid division by zero.

    Returns:
        dict: TEM score ("f1") plus the extracted prediction and reference entities.
    """
    if isinstance(reference_text, list):
        reference_entities = set()
        for ref in reference_text:
            reference_entities.update(extract_entities(ref, KEYWORDS))
    else:
        reference_entities = extract_entities(reference_text, KEYWORDS)

    prediction_entities = extract_entities(prediction_text, KEYWORDS)

    if len(reference_entities) == 0:
        if len(prediction_entities) == 0:
            return {
                "f1": 1.0,
                "prediction_entities": prediction_entities,
                "reference_entities": reference_entities
            }  # Perfect match when both are empty
        else:
            return {
                "f1": epsilon,
                "prediction_entities": prediction_entities,
                "reference_entities": reference_entities
            }  # Minimal score when reference is empty but prediction is not

    # Calculate intersection of entities
    true_positives = len(prediction_entities & reference_entities)

    # Calculate precision and recall with epsilon to avoid division by zero
    precision = (true_positives + epsilon) / (len(prediction_entities) + epsilon)
    recall = (true_positives + epsilon) / (len(reference_entities) + epsilon)

    # Calculate TEM score (F1 score)
    tem_score = (2 * precision * recall) / (precision + recall + epsilon)

    return {
        "f1": tem_score,
        "prediction_entities": prediction_entities,
        "reference_entities": reference_entities
    }

def temporal_f1_score(predictions: List[str], references: List[Union[str, List[str]]], epsilon: float = 1e-10) -> dict:
    """
    Calculate the average TEM score over a list of reference and prediction texts.

    Args:
        references (List[Union[str, List[str]]]): List of reference texts or lists of reference texts.
        predictions (List[str]): List of prediction texts.
        epsilon (float): Small value to avoid division by zero.

    Returns:
        dict: Average TEM score ("f1") plus the per-sample prediction and reference entities.
    """
    assert len(references) == len(predictions), "Reference and prediction lists must have the same length."

    tem_scores = []
    prediction_entities = []
    reference_entities = []

    for pred, ref in zip(predictions, references):
        result = calculate_tem_score(pred, ref, epsilon)
        tem_scores.append(result["f1"])
        prediction_entities.append(result["prediction_entities"])
        reference_entities.append(result["reference_entities"])

    average_f1 = sum(tem_scores) / len(tem_scores)

    return {
        "f1": average_f1,
        "prediction_entities": prediction_entities,
        "reference_entities": reference_entities
    }

# Command-line interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate the average TEM score for reference and prediction texts.")
    parser.add_argument("--predictions", nargs='+', required=True, help="List of prediction texts.")
    parser.add_argument("--reference", nargs='+', required=True, help="List of reference texts or lists of reference texts.")

    args = parser.parse_args()

    # Convert references into a nested list if necessary
    reference_list = [eval(ref) if ref.startswith('[') else ref for ref in args.reference]

    # Calculate and report the average TEM score
    print(temporal_f1_score(predictions=args.predictions, references=reference_list))
libra/mm_utils.py ADDED
@@ -0,0 +1,116 @@
from PIL import Image
from io import BytesIO
import base64

import torch
from transformers import StoppingCriteria
from libra.constants import IMAGE_TOKEN_INDEX


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))

def expand2square(pil_img, background_color=(0, 0, 0)):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []

    try:
        if not images:
            raise ValueError("Input images list is empty.")

        if image_aspect_ratio == 'pad':
            for image in images:
                if not isinstance(image, Image.Image):
                    raise TypeError("All input images must be of type PIL.Image.")
                image = expand2square(image, (0, 0, 0))
                image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                new_images.append(image)
        else:
            return image_processor(images, return_tensors='pt')['pixel_values']

        if new_images and all(x is not None and x.shape == new_images[0].shape for x in new_images):
            new_images = torch.stack(new_images, dim=0)

        return new_images
    except Exception as e:
        print(f"Error processing images: {e}")
        return None


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    # Split the prompt on the <image> placeholder and tokenize each chunk separately,
    # collecting the input-ID list for each chunk.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
libra/model/__init__.py ADDED
@@ -0,0 +1,4 @@
try:
    from .language_model.libra_llama import LibraLlamaForCausalLM, LibraConfig
except:
    pass
libra/model/builder.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Copyright 2024 Xi Zhang
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from libra.model import *
23
+ from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if 'libra' in model_name.lower():
46
+ # Load Libra model
47
+ if 'lora' in model_name.lower() and model_base is None:
48
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
49
+ if 'lora' in model_name.lower() and model_base is not None:
50
+ from libra.model.language_model.libra_llama import LibraConfig
51
+ lora_cfg_pretrained = LibraConfig.from_pretrained(model_path)
52
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
53
+ print('Loading libra from base model...')
54
+ model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
55
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
56
+ if model.lm_head.weight.shape[0] != token_num:
57
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
58
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
59
+
60
+ print('Loading additional Libra weights...')
61
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
62
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
63
+ else:
64
+ from huggingface_hub import hf_hub_download
65
+ def load_from_hf(repo_id, filename, subfolder=None):
66
+ cache_file = hf_hub_download(
67
+ repo_id=repo_id,
68
+ filename=filename,
69
+ subfolder=subfolder)
70
+ return torch.load(cache_file, map_location='cpu')
71
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
72
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
73
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
74
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
75
+ model.load_state_dict(non_lora_trainables, strict=False)
76
+
77
+ from peft import PeftModel
78
+ print('Loading LoRA weights...')
79
+ model = PeftModel.from_pretrained(model, model_path)
80
+ print('Merging LoRA weights...')
81
+ model = model.merge_and_unload()
82
+ print('Model is loaded...')
83
+ elif model_base is not None:
84
+ print('Loading Libra from base model...')
85
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
86
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
87
+ model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
88
+
89
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
90
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
91
+ model.load_state_dict(mm_projector_weights, strict=False)
92
+ else:
93
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
94
+ model = LibraLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
95
+ else:
96
+ # Load language model
97
+ if model_base is not None:
98
+ # PEFT model
99
+ from peft import PeftModel
100
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
101
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
102
+ print(f"Loading LoRA weights from {model_path}")
103
+ model = PeftModel.from_pretrained(model, model_path)
104
+ print(f"Merging weights")
105
+ model = model.merge_and_unload()
106
+ print('Convert to FP16...')
107
+ model.to(torch.float16)
108
+ else:
109
+ use_fast = False
110
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
111
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
112
+
113
+ image_processor = None
114
+
115
+ if 'libra' in model_name.lower():
116
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
117
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
118
+ if mm_use_im_patch_token:
119
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
120
+ if mm_use_im_start_end:
121
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
122
+ model.resize_token_embeddings(len(tokenizer))
123
+
124
+ vision_tower = model.get_vision_tower()
125
+ if not vision_tower.is_loaded:
126
+ vision_tower.load_model()
127
+ vision_tower.to(device=device, dtype=torch.float16)
128
+ image_processor = vision_tower.image_processor
129
+
130
+ if hasattr(model.config, "max_sequence_length"):
131
+ context_len = model.config.max_sequence_length
132
+ else:
133
+ context_len = 2048
134
+
135
+ return tokenizer, model, image_processor, context_len
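A minimal usage sketch for the loader above, assuming a CUDA device is available; the checkpoint id is the default used by libra/serve/cli.py below.

from libra.mm_utils import get_model_name_from_path
from libra.model.builder import load_pretrained_model

model_path = "X-iZhang/libra-v1.0-7b"
model_name = get_model_name_from_path(model_path)   # must contain "libra" to take the multimodal branch
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=model_name,
    load_4bit=False,      # set True for 4-bit NF4 quantization (requires bitsandbytes)
    device="cuda",
)
print(type(model).__name__, context_len)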
libra/model/language_model/libra_llama.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from ..libra_arch import LibraMetaModel, LibraMetaForCausalLM
26
+
27
+
28
+ class LibraConfig(LlamaConfig):
29
+ model_type = "libra"
30
+
31
+ class LibraLlamaModel(LibraMetaModel, LlamaModel):
32
+ config_class = LibraConfig
33
+
34
+ def __init__(self, config: LlamaConfig):
35
+ super(LibraLlamaModel, self).__init__(config)
36
+
37
+
38
+ class LibraLlamaForCausalLM(LlamaForCausalLM, LibraMetaForCausalLM):
39
+ config_class = LibraConfig
40
+
41
+ def __init__(self, config):
42
+ super(LlamaForCausalLM, self).__init__(config)
43
+ self.model = LibraLlamaModel(config)
44
+ self.vocab_size = config.vocab_size
45
+
46
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
47
+
48
+ # Initialize weights and apply final processing
49
+ self.post_init()
50
+
51
+ def get_model(self):
52
+ return self.model
53
+
54
+ def forward(
55
+ self,
56
+ input_ids: torch.LongTensor = None,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
59
+ inputs_embeds: Optional[torch.FloatTensor] = None,
60
+ labels: Optional[torch.LongTensor] = None,
61
+ use_cache: Optional[bool] = None,
62
+ output_attentions: Optional[bool] = None,
63
+ output_hidden_states: Optional[bool] = None,
64
+ images: Optional[torch.FloatTensor] = None,
65
+ return_dict: Optional[bool] = None,
66
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
67
+
68
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
69
+ output_hidden_states = (
70
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
71
+ )
72
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
73
+
74
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
75
+
76
+ outputs = self.model(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ past_key_values=past_key_values,
80
+ inputs_embeds=inputs_embeds,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict
85
+ )
86
+
87
+ hidden_states = outputs[0]
88
+ logits = self.lm_head(hidden_states)
89
+
90
+ loss = None
91
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.21.0/src/transformers/models/gptj/modeling_gptj.py#L847
92
+ if labels is not None:
93
+ # Shift so that tokens < n predict n
94
+ shift_logits = logits[..., :-1, :].contiguous()
95
+ shift_labels = labels[..., 1:].contiguous()
96
+ # Flatten the tokens
97
+ loss_fct = CrossEntropyLoss()
98
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
99
+ shift_labels = shift_labels.view(-1)
100
+ # Enable model/pipeline parallelism
101
+ shift_labels = shift_labels.to(shift_logits.device)
102
+
103
+ loss = loss_fct(shift_logits, shift_labels)
104
+
105
+ if not return_dict:
106
+ output = (logits,) + outputs[1:]
107
+ return ((loss,) + output) if loss is not None else output
108
+
109
+ return CausalLMOutputWithPast(
110
+ loss=loss,
111
+ logits=logits,
112
+ past_key_values=outputs.past_key_values,
113
+ hidden_states=outputs.hidden_states,
114
+ attentions=outputs.attentions,
115
+ )
116
+
117
+ def prepare_inputs_for_generation(
118
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
119
+ ):
120
+ if past_key_values:
121
+ input_ids = input_ids[:, -1:]
122
+
123
+
124
+ if inputs_embeds is not None and past_key_values is None:
125
+ model_inputs = {"inputs_embeds": inputs_embeds}
126
+ else:
127
+ model_inputs = {"input_ids": input_ids}
128
+
129
+ model_inputs.update(
130
+ {
131
+ "past_key_values": past_key_values,
132
+ "use_cache": kwargs.get("use_cache"),
133
+ "attention_mask": attention_mask,
134
+ "images": kwargs.get("images", None),
135
+ }
136
+ )
137
+
138
+ return model_inputs
139
+
140
+ AutoConfig.register("libra", LibraConfig) # Register the LibraConfig to the AutoConfig registry
141
+ AutoModelForCausalLM.register(LibraConfig, LibraLlamaForCausalLM) # Register the LibraLlamaForCausalLM to the AutoModel registry
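Because of the two register calls above, a checkpoint whose config.json declares "model_type": "libra" should be resolvable through the generic Auto classes once libra.model has been imported. A sketch under that assumption:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
import libra.model  # importing the package runs the registration above

config = AutoConfig.from_pretrained("X-iZhang/libra-v1.0-7b")
print(config.model_type)  # -> "libra"

model = AutoModelForCausalLM.from_pretrained(
    "X-iZhang/libra-v1.0-7b",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)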
libra/model/libra_arch.py ADDED
@@ -0,0 +1,333 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from .multimodal_encoder.builder import build_vision_tower
21
+ from .multimodal_projector.builder import build_vision_projector
22
+
23
+ from libra.constants import (
24
+ IGNORE_INDEX,
25
+ IMAGE_TOKEN_INDEX,
26
+ DEFAULT_IMAGE_PATCH_TOKEN,
27
+ DEFAULT_IM_START_TOKEN,
28
+ DEFAULT_IM_END_TOKEN,
29
+ )
30
+
31
+
32
+ class LibraMetaModel:
33
+ """
34
+ LibraMetaModel is a class that initializes and manages a multi-modal model with vision and projection modules.
35
+
36
+ Attributes:
37
+ config (object): Configuration object containing model parameters.
38
+ vision_tower (object): Vision model component.
39
+ mm_projector (object): Multi-modal projection module.
40
+
41
+ Methods:
42
+ __init__(config):
43
+ Initializes the LibraMetaModel with the given configuration.
44
+
45
+ get_vision_tower():
46
+ Retrieves the vision model component. If the vision model is a list, returns the first element.
47
+
48
+ initialize_vision_modules(model_args, fsdp=None):
49
+ Initializes the vision and projection modules based on the provided model arguments.
50
+ Loads pre-trained weights for the multi-modal MLP adapter if available.
51
+ """
52
+ def __init__(self, config):
53
+
54
+ super(LibraMetaModel, self).__init__(config)
55
+
56
+ if hasattr(config, "mm_vision_tower"):
57
+
58
+ self.vision_tower = build_vision_tower(config, delay_load=True)
59
+ self.mm_projector = build_vision_projector(config)
60
+
61
+ def get_vision_tower(self):
62
+ vision_tower = getattr(self, 'vision_tower', None)
63
+ if type(vision_tower) is list:
64
+ vision_tower = vision_tower[0]
65
+ return vision_tower
66
+
67
+ def initialize_vision_modules(self, model_args, fsdp=None):
68
+ vision_tower = model_args.vision_tower
69
+ mm_vision_select_layer = model_args.mm_vision_select_layer
70
+ mm_vision_select_feature = model_args.mm_vision_select_feature
71
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
72
+
73
+ self.config.mm_vision_tower = vision_tower
74
+
75
+ if self.get_vision_tower() is None:
76
+ vision_tower = build_vision_tower(model_args)
77
+
78
+ if fsdp is not None and len(fsdp) > 0:
79
+ self.vision_tower = [vision_tower]
80
+ else:
81
+ self.vision_tower = vision_tower
82
+ else:
83
+ if fsdp is not None and len(fsdp) > 0:
84
+ vision_tower = self.vision_tower[0]
85
+ else:
86
+ vision_tower = self.vision_tower
87
+ vision_tower.load_model()
88
+
89
+ self.config.use_mm_proj = True
90
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
91
+ self.config.mm_hidden_size = vision_tower.hidden_size
92
+ self.config.mm_vision_select_layer = mm_vision_select_layer
93
+ self.config.mm_vision_select_feature = mm_vision_select_feature
94
+
95
+ if getattr(self, 'mm_projector', None) is None:
96
+ self.mm_projector = build_vision_projector(self.config)
97
+ else:
98
+ for p in self.mm_projector.parameters():
99
+ p.requires_grad = True
100
+
101
+ if pretrain_mm_mlp_adapter is not None:
102
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
103
+
104
+ def get_w(weights, keyword):
105
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
106
+
107
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
108
+
109
+
110
+ class LibraMetaForCausalLM(ABC):
111
+
112
+ @abstractmethod
113
+ def get_model(self):
114
+ pass
115
+
116
+ def get_vision_tower(self):
117
+ return self.get_model().get_vision_tower()
118
+
119
+ def encode_images(self, images):
120
+ image_features_temp = self.get_model().get_vision_tower()(images)
121
+ image_features = self.get_model().mm_projector(image_features_temp)
122
+
123
+ return image_features
124
+
125
+ def prepare_inputs_labels_for_multimodal(
126
+ self, input_ids, attention_mask, past_key_values, labels, images
127
+ ):
128
+ """
129
+ Prepare inputs and labels for multimodal tasks, applying different logic based on training or inference phase.
130
+
131
+ Args:
132
+ input_ids (Tensor): IDs of the input token sequence.
133
+ attention_mask (Tensor): Attention mask.
134
+ past_key_values (Tensor): Cached key and value for attention mechanism.
135
+ labels (Tensor): Target labels.
136
+ images (Tensor): Image inputs.
137
+
138
+ Returns:
139
+ Tuple: Processed input_ids, attention_mask, past_key_values, multimodal_features, labels
140
+ """
141
+
142
+ vision_tower = self.get_vision_tower()
143
+
144
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
145
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
146
+ attention_mask = torch.ones(
147
+ (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
148
+ dtype=attention_mask.dtype,
149
+ device=attention_mask.device
150
+ )
151
+ return input_ids, attention_mask, past_key_values, None, labels
152
+
153
+ if input_ids.size(0) != images.size(0) and input_ids.size(0) != images.size(1):
154
+ # print(
155
+ # "Warning: Dimension mismatch detected. Adjust dimensions for beam-search.\n"
156
+ # "Program continues..."
157
+ # )
158
+ num_groups = input_ids.size(0)
159
+ images_1 = images[:num_groups]
160
+ images_2 = images[num_groups:]
161
+ images = torch.cat((images_1, images_2), dim=1)
162
+ images = images.permute(1, 0, 2, 3, 4).contiguous()
163
+
164
+ image_features = self.encode_images(images)
165
+
166
+ new_input_embeds = []
167
+ new_labels = [] if labels is not None else None
168
+ cur_image_idx = 0
169
+ for batch_idx, cur_input_ids in enumerate(input_ids):
170
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
171
+
172
+ cur_image_features = image_features[cur_image_idx]
173
+ cur_input_embeds_temp = self.get_model().embed_tokens(cur_input_ids)
174
+ cur_input_embeds = torch.cat([cur_input_embeds_temp, cur_image_features[0:0]], dim=0)
175
+
176
+ new_input_embeds.append(cur_input_embeds)
177
+ if labels is not None:
178
+ new_labels.append(labels[batch_idx])
179
+ cur_image_idx += 1
180
+
181
+ continue
182
+
183
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
184
+ cur_new_input_embeds = []
185
+
186
+ if labels is not None:
187
+ cur_labels = labels[batch_idx]
188
+ cur_new_labels = []
189
+
190
+ assert cur_labels.shape == cur_input_ids.shape
191
+
192
+ while image_token_indices.numel() > 0:
193
+
194
+ cur_image_features = image_features[cur_image_idx]
195
+ image_token_start = image_token_indices[0]
196
+
197
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
198
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
199
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
200
+ cur_new_input_embeds.append(cur_image_features)
201
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
202
+ if labels is not None:
203
+ cur_new_labels.append(cur_labels[:image_token_start])
204
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
205
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
206
+ cur_labels_temp = cur_labels[image_token_start+2:]
207
+ else:
208
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
209
+ cur_new_input_embeds.append(cur_image_features)
210
+
211
+ if labels is not None:
212
+ cur_new_labels.append(cur_labels[:image_token_start])
213
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
214
+ cur_labels_temp = cur_labels[image_token_start+1:]
215
+
216
+ cur_image_idx += 1
217
+
218
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
219
+ cur_input_ids = cur_input_ids[image_token_start+2:]
220
+ else:
221
+ cur_input_ids = cur_input_ids[image_token_start+1:]
222
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
223
+
224
+ if cur_input_ids.numel() > 0:
225
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
226
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
227
+ else:
228
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
229
+ if labels is not None:
230
+ cur_new_labels.append(cur_labels_temp)
231
+
232
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
233
+
234
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
235
+
236
+ new_input_embeds.append(cur_new_input_embeds)
237
+
238
+ if labels is not None:
239
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
240
+ new_labels.append(cur_new_labels)
241
+
242
+
243
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
244
+ max_len = max(x.shape[0] for x in new_input_embeds)
245
+
246
+ new_input_embeds_align = []
247
+ for cur_new_embed in new_input_embeds:
248
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
249
+ new_input_embeds_align.append(cur_new_embed)
250
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
251
+
252
+ if labels is not None:
253
+ new_labels_align = []
254
+ _new_labels = new_labels
255
+ for cur_new_label in new_labels:
256
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
257
+ new_labels_align.append(cur_new_label)
258
+ new_labels = torch.stack(new_labels_align, dim=0)
259
+
260
+ if attention_mask is not None:
261
+ new_attention_mask = []
262
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
263
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
264
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
265
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
266
+ new_attention_mask.append(cur_new_attention_mask)
267
+ attention_mask = torch.stack(new_attention_mask, dim=0)
268
+ assert attention_mask.shape == new_labels.shape
269
+ else:
270
+
271
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
272
+
273
+ if labels is not None:
274
+ new_labels = torch.stack(new_labels, dim=0)
275
+
276
+ if attention_mask is not None:
277
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
278
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
279
+
280
+ assert attention_mask.shape == new_input_embeds.shape[:2]
281
+
282
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
283
+
284
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
285
+
286
+ if model_args.mm_use_im_patch_token:
287
+
288
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
289
+ self.resize_token_embeddings(len(tokenizer))
290
+
291
+ if model_args.mm_use_im_start_end:
292
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
293
+ self.resize_token_embeddings(len(tokenizer))
294
+
295
+ if num_new_tokens > 0:
296
+ input_embeddings = self.get_input_embeddings().weight.data
297
+ output_embeddings = self.get_output_embeddings().weight.data
298
+
299
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
300
+ dim=0, keepdim=True)
301
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
302
+ dim=0, keepdim=True)
303
+
304
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
305
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
306
+
307
+ if model_args.tune_mm_mlp_adapter:
308
+ for p in self.get_input_embeddings().parameters():
309
+ p.requires_grad = True
310
+ for p in self.get_output_embeddings().parameters():
311
+ p.requires_grad = False
312
+
313
+ if model_args.pretrain_mm_mlp_adapter:
314
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
315
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
316
+ assert num_new_tokens == 2
317
+ if input_embeddings.shape == embed_tokens_weight.shape:
318
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
319
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
320
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
321
+ else:
322
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
323
+
324
+
325
+ elif model_args.mm_use_im_patch_token:
326
+
327
+ if model_args.tune_mm_mlp_adapter:
328
+
329
+ for p in self.get_input_embeddings().parameters():
330
+ p.requires_grad = False
331
+
332
+ for p in self.get_output_embeddings().parameters():
333
+ p.requires_grad = False
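The splice performed by prepare_inputs_labels_for_multimodal above can be hard to follow from the loop alone. The toy sketch below reproduces the core idea for a single sequence: every IMAGE_TOKEN_INDEX placeholder is replaced by a block of projected image features and the matching label slots are masked with IGNORE_INDEX. splice_one_sequence is a hypothetical helper written only for this note, not a function from this repository.

import torch

IMAGE_TOKEN_INDEX = -200   # from libra/constants.py
IGNORE_INDEX = -100        # from libra/constants.py

def splice_one_sequence(input_ids, labels, image_features, embed_tokens):
    # splice projected image features into the embedding sequence at every placeholder
    chunks, label_chunks, start = [], [], 0
    positions = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0].tolist()
    for pos in positions:
        chunks.append(embed_tokens(input_ids[start:pos]))          # text before the placeholder
        chunks.append(image_features)                              # projected image tokens
        label_chunks.append(labels[start:pos])
        label_chunks.append(torch.full((image_features.shape[0],), IGNORE_INDEX, dtype=labels.dtype))
        start = pos + 1
    chunks.append(embed_tokens(input_ids[start:]))                 # trailing text
    label_chunks.append(labels[start:])
    return torch.cat(chunks), torch.cat(label_chunks)

embed = torch.nn.Embedding(10, 4)                      # toy vocab of 10, hidden size 4
ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])    # one image placeholder
labels = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])
feats = torch.randn(3, 4)                              # 3 projected image tokens
emb, new_labels = splice_one_sequence(ids, labels, feats, embed)
print(emb.shape, new_labels.tolist())  # torch.Size([7, 4]) [1, 2, -100, -100, -100, 3, 4]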
libra/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from .clip_encoder import CLIPVisionTower
17
+ from .dino_encoder import DINOVisionTower
18
+
19
+ def build_vision_tower(vision_tower_cfg, **kwargs):
20
+
21
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
22
+
23
+ if vision_tower is None:
24
+ raise ValueError("No vision tower specified in configuration.")
25
+
26
+ is_absolute_path_exists = os.path.exists(vision_tower)
27
+
28
+ if is_absolute_path_exists or vision_tower.startswith("openai") or \
29
+ vision_tower.startswith("facebook") or vision_tower.startswith("microsoft"):
30
+
31
+ if "clip" in vision_tower.lower():
32
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
33
+ elif "dino" in vision_tower.lower():
34
+ return DINOVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
35
+ else:
36
+ raise ValueError(f'Unknown vision model type in vision_tower: {vision_tower}')
37
+
38
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
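A minimal sketch of driving the factory above with a plain namespace instead of a full model config; "microsoft/rad-dino" is an illustrative checkpoint id (any existing path or repo name containing "clip" or "dino" routes to the matching tower).

from types import SimpleNamespace
from libra.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="microsoft/rad-dino",  # illustrative checkpoint id
    mm_vision_select_layer="all",          # stack features from every hidden layer
    mm_vision_select_feature="patch",      # drop the CLS token
)
tower = build_vision_tower(cfg, delay_load=True)  # weights are only fetched once tower.load_model() runs
print(type(tower).__name__)                       # -> DINOVisionTower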
libra/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from transformers import AutoImageProcessor, AutoModel, AutoConfig
19
+
20
+ class CLIPVisionTower(nn.Module):
21
+ def __init__(self, vision_tower, args, delay_load=False):
22
+ super().__init__()
23
+
24
+ self.is_loaded = False
25
+
26
+ self.vision_tower_name = vision_tower
27
+ self.select_layer = args.mm_vision_select_layer
28
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
29
+
30
+ if not delay_load:
31
+ self.load_model()
32
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
33
+ self.load_model()
34
+ else:
35
+ self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)
36
+
37
+ def load_model(self):
38
+ if self.is_loaded:
39
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
40
+ return
41
+
42
+ self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
43
+ self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name)
44
+ self.vision_tower.requires_grad_(False)
45
+
46
+ self.is_loaded = True
47
+
48
+ def get_features(self, images):
49
+ outputs = self.vision_tower(images, output_hidden_states=True)
50
+ hidden_states = outputs.hidden_states
51
+
52
+ if self.select_layer == "all":
53
+ if self.select_feature == "patch":
54
+ all_layers_features = [hidden_state[:, 1:, :].contiguous() for hidden_state in hidden_states[1:]]
55
+ elif self.select_feature == "cls_patch":
56
+ all_layers_features = [hidden_state.contiguous() for hidden_state in hidden_states[1:]]
57
+ else:
58
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
59
+
60
+ return torch.stack(all_layers_features)
61
+ else:
62
+ selected_layer_features = hidden_states[int(self.select_layer)]
63
+
64
+ if self.select_feature == "patch":
65
+ selected_layer_features = selected_layer_features[:, 1:]
66
+ elif self.select_feature == "cls_patch":
67
+ selected_layer_features = selected_layer_features
68
+ else:
69
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
70
+
71
+ return selected_layer_features
72
+
73
+ @torch.no_grad()
74
+ def forward(self, images):
75
+
76
+ if images.shape[0] != 2:
77
+ raise ValueError(
78
+ f"Expected images.shape[0] == 2, but got {images.shape[0]}. "
79
+ "Ensure the input includes both current and previous images."
80
+ )
81
+
82
+ cur_images = images[0]
83
+ prev_images = images[1]
84
+
85
+ cur_features = self.get_features(cur_images)
86
+ prev_features = self.get_features(prev_images)
87
+
88
+ cur_features = cur_features.permute(1, 0, 2, 3)
89
+ prev_features = prev_features.permute(1, 0, 2, 3)
90
+
91
+ # Stack current and previous images along a new dimension
92
+ images_features = torch.stack([cur_features, prev_features])
93
+
94
+ return images_features
95
+
96
+ @property
97
+ def dummy_feature(self):
98
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
99
+
100
+ @property
101
+ def dtype(self):
102
+
103
+ return self.vision_tower.dtype
104
+
105
+ @property
106
+ def device(self):
107
+ return self.vision_tower.device
108
+
109
+ @property
110
+ def config(self):
111
+ if self.is_loaded:
112
+ return self.vision_tower.config
113
+ else:
114
+ return self.cfg_only
115
+
116
+ @property
117
+ def hidden_size(self):
118
+ return self.config.hidden_size
119
+
120
+ @property
121
+ def num_patches(self):
122
+ return (self.config.image_size // self.config.patch_size) ** 2
123
+
124
+ @property
125
+ def num_layers(self):
126
+ return self.config.num_hidden_layers
libra/model/multimodal_encoder/dino_encoder.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from transformers import AutoImageProcessor, AutoModel, AutoConfig
19
+
20
+ class DINOVisionTower(nn.Module):
21
+ def __init__(self, vision_tower, args, delay_load=False):
22
+ super().__init__()
23
+
24
+ self.is_loaded = False
25
+
26
+ self.vision_tower_name = vision_tower
27
+ self.select_layer = args.mm_vision_select_layer
28
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
29
+
30
+ if not delay_load:
31
+ self.load_model()
32
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
33
+ self.load_model()
34
+ else:
35
+ self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)
36
+
37
+ def load_model(self):
38
+ if self.is_loaded:
39
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
40
+ return
41
+
42
+ self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
43
+ self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name)
44
+ self.vision_tower.requires_grad_(False)
45
+
46
+ self.is_loaded = True
47
+
48
+ def get_features(self, images):
49
+ outputs = self.vision_tower(images, output_hidden_states=True)
50
+ hidden_states = outputs.hidden_states
51
+
52
+ if self.select_layer == "all":
53
+ if self.select_feature == "patch":
54
+ all_layers_features = [hidden_state[:, 1:, :].contiguous() for hidden_state in hidden_states[1:]]
55
+ elif self.select_feature == "cls_patch":
56
+ all_layers_features = [hidden_state.contiguous() for hidden_state in hidden_states[1:]]
57
+ else:
58
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
59
+
60
+ return torch.stack(all_layers_features)
61
+ else:
62
+ selected_layer_features = hidden_states[int(self.select_layer)]
63
+
64
+ if self.select_feature == "patch":
65
+ selected_layer_features = selected_layer_features[:, 1:]
66
+ elif self.select_feature == "cls_patch":
67
+ selected_layer_features = selected_layer_features
68
+ else:
69
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
70
+
71
+ return torch.stack([selected_layer_features])
72
+
73
+ @torch.no_grad()
74
+ def forward(self, images):
75
+
76
+ if images.shape[0] != 2:
77
+ raise ValueError(
78
+ f"Expected images.shape[0] == 2, but got {images.shape}. "
79
+ "Ensure the input includes both current and previous images."
80
+ )
81
+
82
+ cur_images = images[0]
83
+ prev_images = images[1]
84
+
85
+ cur_features = self.get_features(cur_images)
86
+ prev_features = self.get_features(prev_images)
87
+
88
+ cur_features = cur_features.permute(1, 0, 2, 3)
89
+ prev_features = prev_features.permute(1, 0, 2, 3)
90
+
91
+ # Stack current and previous images along a new dimension
92
+ images_features = torch.stack([cur_features, prev_features])
93
+
94
+ return images_features
95
+
96
+ @property
97
+ def dummy_feature(self):
98
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
99
+
100
+ @property
101
+ def dtype(self):
102
+
103
+ return self.vision_tower.dtype
104
+
105
+ @property
106
+ def device(self):
107
+ return self.vision_tower.device
108
+
109
+ @property
110
+ def config(self):
111
+ if self.is_loaded:
112
+ return self.vision_tower.config
113
+ else:
114
+ return self.cfg_only
115
+
116
+ @property
117
+ def hidden_size(self):
118
+ return self.config.hidden_size
119
+
120
+ @property
121
+ def num_patches(self):
122
+ return (self.config.image_size // self.config.patch_size) ** 2
123
+
124
+ @property
125
+ def num_layers(self):
126
+ return self.config.num_hidden_layers
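A quick shape check for the two-image interface enforced in forward above. The sketch assumes the "microsoft/rad-dino" checkpoint (a DINOv2 ViT-B variant, consistent with the 12-layer assumption in the TAC projector below); exact patch counts and widths depend on the checkpoint actually configured.

import torch
from types import SimpleNamespace
from libra.model.multimodal_encoder.dino_encoder import DINOVisionTower

args = SimpleNamespace(mm_vision_select_layer="all", mm_vision_select_feature="patch")
tower = DINOVisionTower("microsoft/rad-dino", args=args)  # delay_load defaults to False, so weights load here

images = torch.randn(2, 1, 3, 518, 518)  # dim 0 stacks [current, prior]; batch size 1
with torch.no_grad():
    feats = tower(images)
print(feats.shape)  # e.g. torch.Size([2, 1, 12, 1369, 768]) for a ViT-B/14 backbone at 518x518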
libra/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torchvision.ops as ops
18
+ import re
19
+
20
+
21
+ class TAC(nn.Module):
22
+ def __init__(self, config):
23
+ super(TAC,self).__init__()
24
+
25
+ self.mm_hidden_size = config.mm_hidden_size
26
+ self.hidden_size = config.hidden_size
27
+ self.num_attention_heads = config.num_attention_heads
28
+ self.dropout = 0.1
29
+ self.layers_number = 12 # RAD-DINO hidden layers
30
+
31
+ # LFE
32
+ self.LFE = nn.Sequential(
33
+ ops.SqueezeExcitation(self.layers_number,self.layers_number // 2,activation=nn.GELU),
34
+ nn.Conv2d(self.layers_number,self.layers_number // 2,kernel_size=1,bias=False),
35
+ ops.SqueezeExcitation(self.layers_number // 2,self.layers_number // 4,activation=nn.GELU),
36
+ nn.Conv2d(self.layers_number // 2,self.layers_number // 4,kernel_size=1,bias=False),
37
+ ops.SqueezeExcitation(self.layers_number // 4,1,activation=nn.GELU),
38
+ nn.Conv2d(self.layers_number // 4,1,kernel_size=1,bias=False)
39
+ )
40
+
41
+ self.LFE_prior_bias = nn.Parameter(torch.tensor(0.0, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")))
42
+ self.LFE_cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
43
+
44
+ # self-attention
45
+ self.cur_self_attention = nn.MultiheadAttention(embed_dim=(self.mm_hidden_size), num_heads=self.num_attention_heads,batch_first=True,add_bias_kv=True)
46
+ self.prior_self_attention = nn.MultiheadAttention(embed_dim=(self.mm_hidden_size), num_heads=self.num_attention_heads,batch_first=True,add_bias_kv=True)
47
+ self.cros_attention = nn.MultiheadAttention(embed_dim=(self.mm_hidden_size), num_heads=self.num_attention_heads,batch_first=True,add_bias_kv=True)
48
+
49
+ self.norm1 = nn.LayerNorm(self.mm_hidden_size)
50
+ self.norm2 = nn.LayerNorm(self.mm_hidden_size)
51
+ self.norm3 = nn.LayerNorm(self.mm_hidden_size)
52
+ self.norm4 = nn.LayerNorm(self.mm_hidden_size)
53
+
54
+ self.mlp_attn = nn.Sequential(
55
+ nn.Linear(self.mm_hidden_size, self.mm_hidden_size),
56
+ nn.GELU(),
57
+ nn.Dropout(self.dropout),
58
+ nn.Linear(self.mm_hidden_size, self.mm_hidden_size),
59
+ nn.Dropout(self.dropout)
60
+ )
61
+
62
+ self.mlp_final = nn.Sequential(
63
+ nn.Linear(self.mm_hidden_size, self.hidden_size),
64
+ nn.GELU(),
65
+ nn.Linear(self.hidden_size, self.hidden_size),
66
+ nn.GELU(),
67
+ nn.Linear(self.hidden_size, self.hidden_size),
68
+ nn.GELU(),
69
+ nn.Linear(self.hidden_size, self.hidden_size)
70
+ )
71
+
72
+ self.dropout1 = nn.Dropout(self.dropout)
73
+ self.dropout2 = nn.Dropout(self.dropout)
74
+ self.dropout3 = nn.Dropout(self.dropout)
75
+
76
+ def calculate_cosine_similarity(self, tensor1, tensor2):
77
+
78
+ assert tensor1.shape == tensor2.shape, "The shapes of the two tensors must be the same"
79
+
80
+ tensor1_flat = tensor1.view(tensor1.size(0), -1)
81
+ tensor2_flat = tensor2.view(tensor2.size(0), -1)
82
+
83
+ tensor1_flat_normalized = tensor1_flat / tensor1_flat.norm(dim=-1, keepdim=True)
84
+ tensor2_flat_normalized = tensor2_flat / tensor2_flat.norm(dim=-1, keepdim=True)
85
+
86
+ cosine_similarities = self.LFE_cos(tensor1_flat_normalized, tensor2_flat_normalized)
87
+ cosine_similarities_normalized = ((cosine_similarities + 1) / 2).pow(8)
88
+ cosine_similarities_normalized = cosine_similarities_normalized.view(-1, 1, 1)
89
+
90
+ return cosine_similarities_normalized
91
+
92
+ # self-attention block
93
+ def cur_self_att_block(self,x):
94
+ x = self.cur_self_attention(x,x,x)[0]
95
+ return self.dropout1(x)
96
+ # self-attention block
97
+ def prior_self_att_block(self,x):
98
+ x = self.prior_self_attention(x,x,x)[0]
99
+ return self.dropout2(x)
100
+ # cross attention block
101
+ def cros_att_block(self,x,y):
102
+ x = self.cros_attention(x,y,y)[0]
103
+ return self.dropout3(x)
104
+
105
+ #TFM
106
+ def TFM(self,cur_features,prev_features):
107
+
108
+ cur_features_temp = cur_features
109
+ prev_features_temp = prev_features
110
+
111
+ cos = self.calculate_cosine_similarity(cur_features_temp, prev_features_temp)
112
+ prev_weight = cos * self.LFE_prior_bias
113
+ prev_features_temp = prev_features_temp + prev_weight
114
+
115
+ cur_features = self.norm1(cur_features_temp + self.cur_self_att_block(cur_features_temp))
116
+ prev_features = self.norm2(prev_features_temp + self.prior_self_att_block(prev_features_temp))
117
+ combined_features = self.norm3(cur_features + self.cros_att_block(cur_features,prev_features))
118
+
119
+ output = self.norm4(cur_features_temp + self.mlp_attn(combined_features))
120
+ output = self.mlp_final(output)
121
+
122
+ return output
123
+
124
+ def forward(self, image_features, *args, **kwargs):
125
+ cur_features, prev_features = image_features
126
+
127
+ cur_features = self.LFE(cur_features).squeeze(1)
128
+ prev_features = self.LFE(prev_features).squeeze(1)
129
+
130
+ output = self.TFM(cur_features,prev_features)
131
+
132
+ return output
133
+
134
+ @property
135
+ def config(self):
136
+ return {"mm_projector_type": 'TAC'}
137
+
138
+ class Projector(nn.Module):
139
+ def __init__(self, base_projector):
140
+ super().__init__()
141
+ self.projector = base_projector
142
+
143
+ def forward(self, image_features, *args, **kwargs):
144
+ temp_features = image_features[0].squeeze(1)
145
+ return self.projector(temp_features)
146
+
147
+
148
+ def build_vision_projector(config, delay_load=False, *args,**kwargs):
149
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
150
+
151
+ if projector_type == 'linear':
152
+ linear_layer = nn.Linear(config.mm_hidden_size, config.hidden_size)
153
+ return Projector(linear_layer)
154
+
155
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
156
+ if mlp_gelu_match:
157
+ mlp_depth = int(mlp_gelu_match.group(1))
158
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
159
+ for _ in range(1, mlp_depth):
160
+ modules.append(nn.GELU())
161
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
162
+ return Projector(nn.Sequential(*modules))
163
+
164
+ if projector_type == 'TAC':
165
+ return TAC(config)
166
+
167
+ raise ValueError(f'Unknown projector type: {projector_type}')
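A minimal sketch of building and running the TAC projector on dummy features. The config values below (768-wide encoder, 4096-wide language model, 32 heads, 49 patches) are illustrative stand-ins rather than values fixed by this file; the hard requirements are that num_attention_heads divides mm_hidden_size and that the layer dimension matches the hard-coded 12.

import torch
from types import SimpleNamespace
from libra.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(
    mm_projector_type="TAC",
    mm_hidden_size=768,       # vision encoder width (ViT-B style)
    hidden_size=4096,         # language model width (Llama-7B style)
    num_attention_heads=32,   # must divide mm_hidden_size
)
device = "cuda" if torch.cuda.is_available() else "cpu"
projector = build_vision_projector(cfg).to(device)

# The encoders above emit a [2, batch, 12 layers, num_patches, mm_hidden_size] stack
# (current study first, prior second); 49 patches keeps this toy run light.
image_features = torch.randn(2, 1, 12, 49, 768, device=device)
out = projector(image_features)
print(out.shape)  # -> torch.Size([1, 49, 4096])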
libra/serve/__init__.py ADDED
File without changes
libra/serve/cli.py ADDED
@@ -0,0 +1,201 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from libra.conversation import conv_templates, SeparatorStyle
6
+ from libra.model.builder import load_pretrained_model
7
+ from libra.utils import disable_torch_init
8
+ from libra.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
+
10
+ import requests
11
+ import numpy as np
+ import pydicom
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from pydicom.pixel_data_handlers.util import apply_voi_lut
15
+ from transformers import TextStreamer
16
+
17
+
18
+ def load_images(image_file):
19
+ """
20
+ Load an image from a local file, a URL, or a DICOM file.
21
+
22
+ Args:
23
+ image_file (str): The path or URL of the image file to load.
24
+
25
+ Returns:
26
+ PIL.Image.Image: The loaded image in RGB format.
27
+
28
+ Raises:
29
+ ValueError: If the DICOM file does not contain image data.
30
+ TypeError: If the input is neither a valid file path nor a URL.
31
+ """
32
+ if isinstance(image_file, str):
33
+ # Case 1: Load from URL
34
+ if image_file.startswith(('http://', 'https://')):
35
+ try:
36
+ response = requests.get(image_file)
37
+ response.raise_for_status()
38
+ image = Image.open(BytesIO(response.content)).convert('RGB')
39
+ except Exception as e:
40
+ raise ValueError(f"Error loading image from URL: {image_file}\n{e}")
41
+
42
+ # Case 2: Load from DICOM file
43
+ elif image_file.lower().endswith('.dcm'):
44
+ try:
45
+ dicom = pydicom.dcmread(image_file)
46
+ if 'PixelData' in dicom:
47
+ data = apply_voi_lut(dicom.pixel_array, dicom)
48
+
49
+ # Handle MONOCHROME1 images
50
+ if dicom.PhotometricInterpretation == "MONOCHROME1":
51
+ data = np.max(data) - data
52
+
53
+ # Normalize the image data
54
+ data = data - np.min(data)
55
+ data = data / np.max(data)
56
+ data = (data * 255).astype(np.uint8)
57
+
58
+ # Convert to 3-channel RGB if necessary
59
+ if data.ndim == 2:
60
+ data = np.stack([data] * 3, axis=-1)
61
+
62
+ image = Image.fromarray(data).convert('RGB')
63
+ else:
64
+ raise ValueError("DICOM file does not contain image data")
65
+ except Exception as e:
66
+ raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")
67
+
68
+ # Case 3: Load standard image files (e.g., PNG, JPG)
69
+ else:
70
+ try:
71
+ image = Image.open(image_file).convert('RGB')
72
+ except Exception as e:
73
+ raise ValueError(f"Error loading standard image file: {image_file}\n{e}")
74
+
75
+ else:
76
+ raise TypeError("image_file must be a string representing a file path or URL")
77
+
78
+ return image
79
+
80
+ def main(args):
81
+ """
82
+ Main function to load a pretrained model, process images, and interact with the user through a conversation loop.
83
+ Args:
84
+ args (Namespace): A namespace object containing the following attributes:
85
+ model_path (str): Path to the pretrained model.
86
+ model_base (str): Base model name.
87
+ load_8bit (bool): Flag to load the model in 8-bit precision.
88
+ load_4bit (bool): Flag to load the model in 4-bit precision.
89
+ device (str): Device to load the model on (e.g., 'cuda', 'cpu').
90
+ conv_mode (str, optional): Conversation mode to use. If None, it will be inferred from the model name.
91
+ image_file (list): List of paths to image files to be processed.
92
+ temperature (float): Sampling temperature for text generation.
93
+ max_new_tokens (int): Maximum number of new tokens to generate.
94
+ debug (bool): Flag to enable debug mode for additional output.
95
+ Raises:
96
+ EOFError: If an EOFError is encountered during user input, the loop will exit.
97
+ """
98
+ # Model
99
+ disable_torch_init()
100
+
101
+ model_name = get_model_name_from_path(args.model_path)
102
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
103
+
104
+ if 'libra' in model_name.lower():
105
+ conv_mode = "libra_v1"
106
+
107
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
108
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
109
+ else:
110
+ args.conv_mode = conv_mode
111
+
112
+ conv = conv_templates[args.conv_mode].copy()
113
+ roles = conv.roles
114
+
115
+ image=[]
116
+ for path in args.image_file:
117
+ img = load_images(path)
118
+ image.append(img)
119
+
120
+ # set dummy prior image
121
+ if len(image) == 1:
122
+ print("Contains only current image. Adding a dummy prior image.")
123
+ image.append(image[0])
124
+
125
+ processed_images = []
126
+ for img_data in image:
127
+ image_temp = process_images([img_data], image_processor, model.config)[0]
128
+ image_temp = image_temp.to(device='cuda',non_blocking=True)
129
+ processed_images.append(image_temp)
130
+
131
+ cur_images = [processed_images[0]]
132
+ prior_images = [processed_images[1]]
133
+ image_tensor = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])
134
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
135
+
136
+ while True:
137
+ try:
138
+ inp = input(f"{roles[0]}: ")
139
+ except EOFError:
140
+ inp = ""
141
+ if not inp:
142
+ print("exit...")
143
+ break
144
+
145
+ print(f"{roles[1]}: ", end="")
146
+
147
+ if image is not None:
148
+ # first message
149
+ if model.config.mm_use_im_start_end:
150
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
151
+ else:
152
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
153
+ image = None
154
+
155
+ conv.append_message(conv.roles[0], inp)
156
+ conv.append_message(conv.roles[1], None)
157
+ prompt = conv.get_prompt()
158
+
159
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
160
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
161
+ keywords = [stop_str]
162
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
163
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
164
+
165
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
166
+ pad_token_id = tokenizer.pad_token_id
167
+
168
+ with torch.inference_mode():
169
+ output_ids = model.generate(
170
+ input_ids,
171
+ images=image_tensor,
172
+ do_sample=True if args.temperature > 0 else False,
173
+ temperature=args.temperature,
174
+ max_new_tokens=args.max_new_tokens,
175
+ streamer=streamer,
176
+ use_cache=True,
177
+ attention_mask=attention_mask,
178
+ pad_token_id=pad_token_id,
179
+ stopping_criteria=[stopping_criteria])
180
+
181
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:],skip_special_tokens=True).strip()
182
+ conv.messages[-1][-1] = outputs
183
+
184
+ if args.debug:
185
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
186
+
187
+ if __name__ == "__main__":
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument("--model-path", type=str, default="X-iZhang/libra-v1.0-7b")
190
+ parser.add_argument("--model-base", type=str, default=None)
191
+ parser.add_argument("--image-file", type=str, nargs="+", required=True, help="List of image files to process.")
192
+ parser.add_argument("--device", type=str, default="cuda")
193
+ parser.add_argument("--conv-mode", type=str, default="libra_v1")
194
+ parser.add_argument("--temperature", type=float, default=0.5)
195
+ parser.add_argument("--max-new-tokens", type=int, default=512)
196
+ parser.add_argument("--load-8bit", action="store_true")
197
+ parser.add_argument("--load-4bit", action="store_true")
198
+ parser.add_argument("--debug", action="store_true")
199
+ args = parser.parse_args()
200
+
201
+ main(args)
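The script is meant to be launched from the command line; the sketch below shows an equivalent programmatic call, with placeholder image paths that must be replaced by real files (current study first, prior second).

# Roughly equivalent to:
#   python -m libra.serve.cli --model-path X-iZhang/libra-v1.0-7b \
#       --image-file current.jpg prior.jpg --temperature 0.5
from argparse import Namespace
from libra.serve.cli import main

args = Namespace(
    model_path="X-iZhang/libra-v1.0-7b",
    model_base=None,
    image_file=["current.jpg", "prior.jpg"],  # placeholder paths
    device="cuda",
    conv_mode="libra_v1",
    temperature=0.5,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
)
main(args)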
libra/train/libra_trainer.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import torch
3
+
4
+ from torch.utils.data import Sampler
5
+
6
+ from transformers import Trainer
7
+
8
+ from transformers.trainer import (
9
+ is_sagemaker_mp_enabled,
10
+ get_parameter_names,
11
+ has_length,
12
+ ALL_LAYERNORM_LAYERS,
13
+ logger,
14
+ )
15
+ from typing import List, Optional
16
+
17
+
18
+ def maybe_zero_3(param, ignore_status=False, name=None):
19
+ from deepspeed import zero
20
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
21
+ if hasattr(param, "ds_id"):
22
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
23
+ if not ignore_status:
24
+ print(name, 'no ignore status')
25
+ with zero.GatheredParameters([param]):
26
+ param = param.data.detach().cpu().clone()
27
+ else:
28
+ param = param.detach().cpu().clone()
29
+ return param
30
+
31
+
32
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
33
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
34
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
35
+ return to_return
36
+
37
+
38
+ def split_to_even_chunks(indices, lengths, num_chunks):
39
+ """
40
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
41
+ """
42
+
43
+ if len(indices) % num_chunks != 0:
44
+ return [indices[i::num_chunks] for i in range(num_chunks)]
45
+
46
+ num_indices_per_chunk = len(indices) // num_chunks
47
+
48
+ chunks = [[] for _ in range(num_chunks)]
49
+ chunks_lengths = [0 for _ in range(num_chunks)]
50
+ for index in indices:
51
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
52
+ chunks[shortest_chunk].append(index)
53
+ chunks_lengths[shortest_chunk] += lengths[index]
54
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
55
+ chunks_lengths[shortest_chunk] = float("inf")
56
+
57
+ return chunks
58
+
59
+
60
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
61
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
62
+ assert all(l != 0 for l in lengths), "Should not have zero length."
63
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
64
+ # all samples are in the same modality
65
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
66
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
67
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
68
+
69
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
70
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
71
+ megabatch_size = world_size * batch_size
72
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
73
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
74
+
75
+ last_mm = mm_megabatches[-1]
76
+ last_lang = lang_megabatches[-1]
77
+ additional_batch = last_mm + last_lang
78
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
79
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
80
+ megabatches = [megabatches[i] for i in megabatch_indices]
81
+
82
+ if len(additional_batch) > 0:
83
+ megabatches.append(sorted(additional_batch))
84
+
85
+ return [i for megabatch in megabatches for i in megabatch]
86
+
87
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
88
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
89
+ indices = torch.randperm(len(lengths), generator=generator)
90
+ megabatch_size = world_size * batch_size
91
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
92
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
93
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
94
+
95
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
96
+
97
+
98
+ class LengthGroupedSampler(Sampler):
99
+ r"""
100
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
101
+ keeping a bit of randomness.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ batch_size: int,
107
+ world_size: int,
108
+ lengths: Optional[List[int]] = None,
109
+ generator=None,
110
+ group_by_modality: bool = False,
111
+ ):
112
+ if lengths is None:
113
+ raise ValueError("Lengths must be provided.")
114
+
115
+ self.batch_size = batch_size
116
+ self.world_size = world_size
117
+ self.lengths = lengths
118
+ self.generator = generator
119
+ self.group_by_modality = group_by_modality
120
+
121
+ def __len__(self):
122
+ return len(self.lengths)
123
+
124
+ def __iter__(self):
125
+ if self.group_by_modality:
126
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
127
+ else:
128
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
129
+ return iter(indices)
130
+
131
+
132
+ class LibraTrainer(Trainer):
133
+
134
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
135
+ if self.train_dataset is None or not has_length(self.train_dataset):
136
+ return None
137
+
138
+ if self.args.group_by_modality_length:
139
+ lengths = self.train_dataset.modality_lengths
140
+ return LengthGroupedSampler(
141
+ self.args.train_batch_size,
142
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
143
+ lengths=lengths,
144
+ group_by_modality=True,
145
+ )
146
+ else:
147
+ return super()._get_train_sampler()
148
+
149
+
150
+ def create_optimizer(self):
151
+ """
152
+ Setup the optimizer.
153
+
154
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
155
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
156
+ """
157
+ if is_sagemaker_mp_enabled():
158
+ return super().create_optimizer()
159
+
160
+ opt_model = self.model
161
+
162
+ if self.optimizer is None:
163
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
164
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
165
+ if self.args.mm_projector_lr is not None:
166
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
167
+ optimizer_grouped_parameters = [
168
+ {
169
+ "params": [
170
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
171
+ ],
172
+ "weight_decay": self.args.weight_decay,
173
+ },
174
+ {
175
+ "params": [
176
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
177
+ ],
178
+ "weight_decay": 0.0,
179
+ },
180
+ {
181
+ "params": [
182
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
183
+ ],
184
+ "weight_decay": self.args.weight_decay,
185
+ "lr": self.args.mm_projector_lr,
186
+ },
187
+ {
188
+ "params": [
189
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
190
+ ],
191
+ "weight_decay": 0.0,
192
+ "lr": self.args.mm_projector_lr,
193
+ },
194
+ ]
195
+ else:
196
+ optimizer_grouped_parameters = [
197
+ {
198
+ "params": [
199
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
200
+ ],
201
+ "weight_decay": self.args.weight_decay,
202
+ },
203
+ {
204
+ "params": [
205
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
206
+ ],
207
+ "weight_decay": 0.0,
208
+ },
209
+ ]
210
+
211
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
212
+
213
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
214
+ if optimizer_cls.__name__ == "Adam8bit":
215
+ import bitsandbytes
216
+
217
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
218
+
219
+ skipped = 0
220
+ for module in opt_model.modules():
221
+ if isinstance(module, nn.Embedding):
222
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
223
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
224
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
225
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
226
+ logger.info(f"skipped: {skipped/2**20}M params")
227
+
228
+
229
+ return self.optimizer
230
+
231
+ def _save_checkpoint(self, model, trial, metrics=None):
232
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
233
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
234
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
235
+
236
+ run_dir = self._get_output_dir(trial=trial)
237
+ output_dir = os.path.join(run_dir, checkpoint_folder)
238
+
239
+ # Only save Adapter
240
+ keys_to_match = ['mm_projector', 'vision_resampler']
241
+ if getattr(self.args, "use_im_start_end", False):
242
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
243
+
244
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
245
+
246
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
247
+ self.model.config.save_pretrained(output_dir)
248
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
249
+
250
+ super(LibraTrainer, self)._save_checkpoint(model, trial, metrics)
251
+ else:
252
+ super(LibraTrainer, self)._save_checkpoint(model, trial, metrics)
253
+
254
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
255
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
256
+ pass
257
+ else:
258
+ super(LibraTrainer, self)._save(output_dir, state_dict)
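A minimal usage sketch of the sampler defined above (illustrative, not part of the commit; the import path assumes this file lands at libra/train/libra_trainer.py, which is how train.py imports it):

from libra.train.libra_trainer import LengthGroupedSampler

# positive lengths mark multimodal samples, negative lengths mark language-only samples
lengths = [120, 80, -60, 200, -40, 90, -150, 70]
sampler = LengthGroupedSampler(batch_size=2, world_size=2, lengths=lengths, group_by_modality=True)
print(list(iter(sampler)))  # a permutation of 0..7, grouped by modality and similar length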
libra/train/llama2_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ Directly copied the code from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py and made some adjustments
3
+ """
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from flash_attn import __version__ as flash_attn_version
9
+ from flash_attn.bert_padding import pad_input, unpad_input
10
+ from flash_attn.flash_attn_interface import (
11
+ flash_attn_func,
12
+ flash_attn_varlen_kvpacked_func,
13
+ )
14
+ from transformers.models.llama.modeling_llama import (
15
+ LlamaAttention,
16
+ LlamaModel,
17
+ rotate_half,
18
+ )
19
+
20
+
21
+ def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
22
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
23
+ gather_indices = gather_indices.repeat(
24
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
25
+ )
26
+ bsz = gather_indices.shape[0]
27
+ cos, sin = (
28
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
29
+ for x in cos_sin
30
+ )
31
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
32
+ return q, k
33
+
34
+
35
+ def forward(
36
+ self,
37
+ hidden_states: torch.Tensor,
38
+ attention_mask: Optional[torch.Tensor] = None,
39
+ position_ids: Optional[torch.Tensor] = None,
40
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
41
+ output_attentions: bool = False,
42
+ use_cache: bool = False,
43
+ padding_mask: Optional[torch.Tensor] = None,
44
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
45
+ if output_attentions:
46
+ warnings.warn(
47
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
48
+ )
49
+
50
+ bsz, q_len, _ = hidden_states.size()
51
+ kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
52
+
53
+ q, k, v = (
54
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
55
+ for op, nh in (
56
+ (self.q_proj, self.num_heads),
57
+ (self.k_proj, kv_heads),
58
+ (self.v_proj, kv_heads),
59
+ )
60
+ )
61
+ # shape: (b, s, num_heads, head_dim)
62
+
63
+ kv_seq_len = k.shape[1]
64
+ past_kv_len = 0
65
+ if past_key_value is not None:
66
+ past_kv_len = past_key_value[0].shape[2]
67
+ kv_seq_len += past_kv_len
68
+
69
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
70
+ q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
71
+
72
+ if past_key_value is not None:
73
+ assert (
74
+ flash_attn_version >= "2.1.0"
75
+ ), "past_key_value support requires flash-attn >= 2.1.0"
76
+ # reuse k, v
77
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
78
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
79
+
80
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
81
+
82
+ if attention_mask is None:
83
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
84
+ bsz, q_len, -1
85
+ )
86
+ else:
87
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
88
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
89
+ kv, _, cu_k_lens, max_k = unpad_input(
90
+ torch.stack((k, v), dim=2), attention_mask
91
+ )
92
+ output_unpad = flash_attn_varlen_kvpacked_func(
93
+ q,
94
+ kv,
95
+ cu_q_lens,
96
+ cu_k_lens,
97
+ max_s,
98
+ max_k,
99
+ 0.0,
100
+ softmax_scale=None,
101
+ causal=True,
102
+ )
103
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
104
+ output = pad_input(output_unpad, indices, bsz, q_len)
105
+
106
+ return self.o_proj(output), None, past_key_value
107
+
108
+
109
+ # Disable the transformation of the attention mask in LlamaModel as flash attention
110
+ # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
111
+ def _prepare_decoder_attention_mask(
112
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
113
+ ):
114
+ # [bsz, seq_len]
115
+ if past_key_values_length > 0 and attention_mask is not None:
116
+ attention_mask = torch.cat(
117
+ (
118
+ torch.full(
119
+ (input_shape[0], past_key_values_length),
120
+ True,
121
+ dtype=attention_mask.dtype,
122
+ device=attention_mask.device,
123
+ ),
124
+ attention_mask,
125
+ ),
126
+ dim=-1,
127
+ )
128
+
129
+ if attention_mask is not None and torch.all(attention_mask):
130
+ return None # This uses the faster call when training with full samples
131
+
132
+ return attention_mask
133
+
134
+
135
+ def replace_llama_attn_with_flash_attn():
136
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
137
+ if cuda_major < 8:
138
+ warnings.warn(
139
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
140
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
141
+ )
142
+
143
+ LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
144
+ LlamaAttention.forward = forward
145
+
146
+
147
+ def test():
148
+ from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
149
+ from transformers.models.llama.configuration_llama import LlamaConfig
150
+
151
+ config = LlamaConfig(
152
+ hidden_size=1024,
153
+ intermediate_size=128,
154
+ num_hidden_layers=1,
155
+ num_attention_heads=8,
156
+ max_position_embeddings=16,
157
+ )
158
+ device = torch.device("cuda")
159
+ model = LlamaModel(config)
160
+ attn = LlamaAttention(config).to(device).half()
161
+ bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
162
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
163
+ -1, seqlen
164
+ )
165
+
166
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
167
+ for i in range(4):
168
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
169
+ if i:
170
+ mask[0, -i:] = False
171
+ mask[1, :i] = False
172
+
173
+ lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
174
+ ref, _, _ = attn.forward(
175
+ hidden, attention_mask=lmask, position_ids=position_ids
176
+ )
177
+
178
+ fast, _, _ = fastchat_forward(
179
+ attn, hidden, attention_mask=mask, position_ids=position_ids
180
+ )
181
+
182
+ lmask = _prepare_decoder_attention_mask(
183
+ model, mask, hidden.shape[:2], hidden, 0
184
+ )
185
+ test, _, _ = forward(
186
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
187
+ )
188
+
189
+ print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
190
+ print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
191
+ print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
192
+ print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
193
+ print(f"allclose(fast, test) = {torch.allclose(fast, test)}")
194
+
195
+ with torch.no_grad():
196
+ # Also check that past_kv is handled properly
197
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
198
+ part_len = seqlen // 4
199
+ assert part_len * 4 == seqlen
200
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
201
+ mask[0, -2:] = False
202
+ lmask = _prepare_decoder_attention_mask(
203
+ model, mask, hidden.shape[:2], hidden, 0
204
+ )
205
+ oneshot, _, _ = forward(
206
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
207
+ )
208
+ parts = []
209
+ past_kv, past_kv_len = None, 0
210
+ for i in range(4):
211
+ start = part_len * i
212
+ end = start + part_len
213
+ hidden_part = hidden[:, start:end, ...]
214
+ lmask = _prepare_decoder_attention_mask(
215
+ model,
216
+ mask[:, start:end],
217
+ hidden_part.shape[:2],
218
+ hidden_part,
219
+ past_kv_len,
220
+ )
221
+ part, _, past_kv = forward(
222
+ attn,
223
+ hidden_part.clone(),
224
+ attention_mask=lmask,
225
+ position_ids=position_ids[:, start:end],
226
+ past_key_value=past_kv,
227
+ use_cache=True,
228
+ )
229
+ parts.append(part)
230
+ past_kv_len = past_kv[0].shape[2]
231
+
232
+ print(
233
+ f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
234
+ )
235
+ print(
236
+ f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
237
+ )
238
+
239
+
240
+ if __name__ == "__main__":
241
+ test()
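A minimal usage sketch (illustrative, not part of the commit; assumes flash-attn is installed, an Ampere-or-newer GPU is available, and the checkpoint name is a placeholder). The patch must be applied before the model is instantiated:

import torch
import transformers
from libra.train.llama2_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()  # swaps LlamaAttention.forward for the flash-attn version
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()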
libra/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Directly copied the code from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py and made some adjustments
3
+ """
4
+
5
+ from typing import Optional, Tuple
6
+ import warnings
7
+
8
+ import torch
9
+
10
+ import transformers
11
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
12
+
13
+ try:
14
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
15
+ except ImportError:
16
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
17
+ from flash_attn.bert_padding import unpad_input, pad_input
18
+
19
+
20
+ def forward(
21
+ self,
22
+ hidden_states: torch.Tensor,
23
+ attention_mask: Optional[torch.Tensor] = None,
24
+ position_ids: Optional[torch.Tensor] = None,
25
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
26
+ output_attentions: bool = False,
27
+ use_cache: bool = False,
28
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
29
+ if output_attentions:
30
+ warnings.warn(
31
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
32
+ )
33
+
34
+ bsz, q_len, _ = hidden_states.size()
35
+
36
+ query_states = (
37
+ self.q_proj(hidden_states)
38
+ .view(bsz, q_len, self.num_heads, self.head_dim)
39
+ .transpose(1, 2)
40
+ )
41
+ key_states = (
42
+ self.k_proj(hidden_states)
43
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
44
+ .transpose(1, 2)
45
+ )
46
+ value_states = (
47
+ self.v_proj(hidden_states)
48
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
49
+ .transpose(1, 2)
50
+ ) # shape: (b, num_heads, s, head_dim)
51
+
52
+ kv_seq_len = key_states.shape[-2]
53
+ if past_key_value is not None:
54
+ kv_seq_len += past_key_value[0].shape[-2]
55
+
56
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
57
+ query_states, key_states = apply_rotary_pos_emb(
58
+ query_states, key_states, cos, sin, position_ids
59
+ )
60
+
61
+ if past_key_value is not None:
62
+ # reuse k, v
63
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
64
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
65
+
66
+ past_key_value = (key_states, value_states) if use_cache else None
67
+
68
+ # repeat k/v heads if n_kv_heads < n_heads
69
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
70
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
71
+
72
+ # Transform the data into the format required by flash attention
73
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
74
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
75
+ key_padding_mask = attention_mask
76
+
77
+ if key_padding_mask is None:
78
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
79
+ cu_q_lens = torch.arange(
80
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
81
+ )
82
+ max_s = q_len
83
+ output = flash_attn_unpadded_qkvpacked_func(
84
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
85
+ )
86
+ output = output.view(bsz, q_len, -1)
87
+ else:
88
+ qkv = qkv.reshape(bsz, q_len, -1)
89
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
90
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
91
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
92
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
93
+ )
94
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
95
+ output = pad_input(output_unpad, indices, bsz, q_len)
96
+
97
+ return self.o_proj(output), None, past_key_value
98
+
99
+
100
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
101
+ # requires the attention mask to be the same as the key_padding_mask
102
+ def _prepare_decoder_attention_mask(
103
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
104
+ ):
105
+ # [bsz, seq_len]
106
+ return attention_mask
107
+
108
+
109
+ def replace_llama_attn_with_flash_attn():
110
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
111
+ if cuda_major < 8:
112
+ warnings.warn(
113
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
114
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
115
+ )
116
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
117
+ _prepare_decoder_attention_mask
118
+ )
119
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
libra/train/llama_xformers_attn_monkey_patch.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ Directly copied the code from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_xformers_attn_monkey_patch.py and made some adjustments
3
+ """
4
+
5
+ import logging
6
+ import math
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import transformers.models.llama.modeling_llama
11
+ from torch import nn
12
+
13
+ try:
14
+ import xformers.ops
15
+ except ImportError:
16
+ logging.error("xformers not found! Please install it before trying to use it.")
17
+
18
+
19
+ def replace_llama_attn_with_xformers_attn():
20
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21
+
22
+
23
+ def xformers_forward(
24
+ self,
25
+ hidden_states: torch.Tensor,
26
+ attention_mask: Optional[torch.Tensor] = None,
27
+ position_ids: Optional[torch.LongTensor] = None,
28
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
29
+ output_attentions: bool = False,
30
+ use_cache: bool = False,
31
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32
+ # pylint: disable=duplicate-code
33
+ bsz, q_len, _ = hidden_states.size()
34
+
35
+ query_states = (
36
+ self.q_proj(hidden_states)
37
+ .view(bsz, q_len, self.num_heads, self.head_dim)
38
+ .transpose(1, 2)
39
+ )
40
+ key_states = (
41
+ self.k_proj(hidden_states)
42
+ .view(bsz, q_len, self.num_heads, self.head_dim)
43
+ .transpose(1, 2)
44
+ )
45
+ value_states = (
46
+ self.v_proj(hidden_states)
47
+ .view(bsz, q_len, self.num_heads, self.head_dim)
48
+ .transpose(1, 2)
49
+ )
50
+
51
+ kv_seq_len = key_states.shape[-2]
52
+ if past_key_value is not None:
53
+ kv_seq_len += past_key_value[0].shape[-2]
54
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55
+ (
56
+ query_states,
57
+ key_states,
58
+ ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59
+ query_states, key_states, cos, sin, position_ids
60
+ )
61
+ # [bsz, nh, t, hd]
62
+
63
+ if past_key_value is not None:
64
+ # reuse k, v, self_attention
65
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
66
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
67
+
68
+ past_key_value = (key_states, value_states) if use_cache else None
69
+
70
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
71
+ if not output_attentions:
72
+ query_states = query_states.transpose(1, 2)
73
+ key_states = key_states.transpose(1, 2)
74
+ value_states = value_states.transpose(1, 2)
75
+
76
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
80
+ attn_output = xformers.ops.memory_efficient_attention(
81
+ query_states, key_states, value_states, attn_bias=None
82
+ )
83
+ else:
84
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
85
+ attn_output = xformers.ops.memory_efficient_attention(
86
+ query_states,
87
+ key_states,
88
+ value_states,
89
+ attn_bias=xformers.ops.LowerTriangularMask(),
90
+ )
91
+ attn_weights = None
92
+ else:
93
+ attn_weights = torch.matmul(
94
+ query_states, key_states.transpose(2, 3)
95
+ ) / math.sqrt(self.head_dim)
96
+
97
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98
+ raise ValueError(
99
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
100
+ f" {attn_weights.size()}"
101
+ )
102
+
103
+ if attention_mask is not None:
104
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105
+ raise ValueError(
106
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107
+ )
108
+ attn_weights = attn_weights + attention_mask
109
+ attn_weights = torch.max(
110
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111
+ )
112
+
113
+ # upcast attention to fp32
114
+ attn_weights = nn.functional.softmax(
115
+ attn_weights, dim=-1, dtype=torch.float32
116
+ ).to(query_states.dtype)
117
+ attn_output = torch.matmul(attn_weights, value_states)
118
+
119
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120
+ raise ValueError(
121
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122
+ f" {attn_output.size()}"
123
+ )
124
+
125
+ attn_output = attn_output.transpose(1, 2)
126
+
127
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128
+ attn_output = self.o_proj(attn_output)
129
+ return attn_output, attn_weights, past_key_value
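A minimal usage sketch (illustrative, not part of the commit; assumes xformers is installed and the checkpoint name is a placeholder). As with the flash-attention patch, it must run before the model is built:

import transformers
from libra.train.llama_xformers_attn_monkey_patch import replace_llama_attn_with_xformers_attn

replace_llama_attn_with_xformers_attn()  # patches LlamaAttention.forward to use memory_efficient_attention
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")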
libra/train/train.py ADDED
@@ -0,0 +1,1434 @@
1
+ # Copyright 2024 Xi Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import copy
17
+ import numpy as np
18
+ from dataclasses import dataclass, field
19
+ import json
20
+ import logging
21
+ import pathlib
22
+ from typing import Dict, Optional, Sequence, List, Union
23
+
24
+ import random
25
+ import torch
26
+ import shutil
27
+ import evaluate
28
+
29
+ import transformers
30
+ import tokenizers
31
+ from transformers import EvalPrediction, TrainerCallback
32
+
33
+ from libra.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
34
+ from torch.utils.data import Dataset
35
+ from libra.train.libra_trainer import LibraTrainer
36
+
37
+ from libra import conversation as conversation_lib
38
+ from libra.model import *
39
+ from libra.mm_utils import tokenizer_image_token
40
+ from libra.eval import temporal_f1_score
41
+
42
+ from PIL import Image
43
+ import pydicom
44
+ from pydicom.pixel_data_handlers.util import apply_voi_lut
+ import requests  # used by load_images() for http(s) image URLs
+ from io import BytesIO  # used by load_images() to decode downloaded image bytes
45
+
46
+ local_rank = None
47
+
48
+
49
+ def rank0_print(*args):
50
+ if local_rank == 0:
51
+ print(*args)
52
+
53
+ from packaging import version
54
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
55
+
56
+ @dataclass
57
+ class ModelArguments:
58
+ model_name_or_path: Optional[str] = field(default="libra")
59
+ version: Optional[str] = field(default="libra_v1")
60
+ freeze_backbone: bool = field(default=False)
61
+ tune_mm_mlp_adapter: bool = field(default=False)
62
+ vision_tower: Optional[str] = field(default=None)
63
+ mm_vision_select_layer: Optional[Union[int, str]] = field(
64
+ default=-1,
65
+ metadata={"help": "Select specific vision layer (e.g., -1, -2) or 'all' for all layers."}
66
+ )
67
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
68
+ mm_projector_type: Optional[str] = field(default='linear')
69
+ mm_use_im_start_end: bool = field(default=False)
70
+ mm_use_im_patch_token: bool = field(default=True)
71
+ mm_vision_select_feature: Optional[str] = field(
72
+ default="patch",
73
+ metadata={"help": "Select feature type: 'patch' or 'cls_patch'."}
74
+ )
75
+ compute_metrics: bool = field(
76
+ default=False,
77
+ metadata={"help": "Optional callable for computing metrics during evaluation during training."}
78
+ )
79
+
80
+ @dataclass
81
+ class DataArguments:
82
+ data_path: str = field(default=None,
83
+ metadata={"help": "Path to the training data."})
84
+ lazy_preprocess: bool = False
85
+ is_multimodal: bool = False
86
+ image_folder: Optional[str] = field(default=None)
87
+ image_aspect_ratio: str = 'square'
88
+ validation_data_path: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Path to the validation data."}
91
+ )
92
+
93
+
94
+ @dataclass
95
+ class TrainingArguments(transformers.TrainingArguments):
96
+ cache_dir: Optional[str] = field(default=None)
97
+ optim: str = field(default="adamw_torch")
98
+ remove_unused_columns: bool = field(default=False)
99
+ freeze_mm_mlp_adapter: bool = field(default=False)
100
+ mpt_attn_impl: Optional[str] = field(default="triton")
101
+ model_max_length: int = field(
102
+ default=512,
103
+ metadata={
104
+ "help":
105
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
106
+ },
107
+ )
108
+ double_quant: bool = field(
109
+ default=True,
110
+ metadata={"help": "Compress the quantization statistics through double quantization."}
111
+ )
112
+ quant_type: str = field(
113
+ default="nf4",
114
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
115
+ )
116
+ bits: int = field(
117
+ default=16,
118
+ metadata={"help": "How many bits to use."}
119
+ )
120
+ lora_enable: bool = False
121
+ lora_r: int = 64
122
+ lora_alpha: int = 16
123
+ lora_dropout: float = 0.05
124
+ lora_weight_path: str = ""
125
+ lora_bias: str = "none"
126
+ mm_projector_lr: Optional[float] = None
127
+ group_by_modality_length: bool = field(default=False)
128
+
129
+
130
+ def maybe_zero_3(param, ignore_status=False, name=None):
131
+ from deepspeed import zero
132
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
133
+ if hasattr(param, "ds_id"):
134
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
135
+ if not ignore_status:
136
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
137
+ with zero.GatheredParameters([param]):
138
+ param = param.data.detach().cpu().clone()
139
+ else:
140
+ param = param.detach().cpu().clone()
141
+ return param
142
+
143
+
144
+ # Borrowed from peft.utils.get_peft_model_state_dict
145
+ def get_peft_state_maybe_zero_3(named_params, bias):
146
+ if bias == "none":
147
+ to_return = {k: t for k, t in named_params if "lora_" in k}
148
+ elif bias == "all":
149
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
150
+ elif bias == "lora_only":
151
+ to_return = {}
152
+ maybe_lora_bias = {}
153
+ lora_bias_names = set()
154
+ for k, t in named_params:
155
+ if "lora_" in k:
156
+ to_return[k] = t
157
+ bias_name = k.split("lora_")[0] + "bias"
158
+ lora_bias_names.add(bias_name)
159
+ elif "bias" in k:
160
+ maybe_lora_bias[k] = t
161
+ for k, t in maybe_lora_bias:
162
+ if bias_name in lora_bias_names:
163
+ to_return[bias_name] = t
164
+ else:
165
+ raise NotImplementedError
166
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
167
+ return to_return
168
+
169
+
170
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
171
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
172
+ if require_grad_only:
173
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
174
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
175
+ return to_return
176
+
177
+
178
+ def get_non_vision_tower_state_maybe_zero_3(named_params):
179
+
180
+ to_return = {k: t for k, t in named_params if "vision_tower" not in k}
181
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
182
+
183
+ return to_return
184
+
185
+
186
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
187
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
188
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
189
+ return to_return
190
+
191
+
192
+ def find_all_linear_names(model):
193
+ cls = torch.nn.Linear
194
+ lora_module_names = set()
195
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
196
+ for name, module in model.named_modules():
197
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
198
+ continue
199
+ if isinstance(module, cls):
200
+ names = name.split('.')
201
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
202
+
203
+ if 'lm_head' in lora_module_names: # needed for 16-bit
204
+ lora_module_names.remove('lm_head')
205
+ return list(lora_module_names)
206
+
207
+
208
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
209
+ output_dir: str):
210
+ """Collects the state dict and dump to disk."""
211
+
212
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
213
+ # Only save Adapter
214
+ keys_to_match = ['mm_projector']
215
+ if getattr(trainer.args, "use_im_start_end", False):
216
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
217
+
218
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
219
+ trainer.model.config.save_pretrained(output_dir)
220
+
221
+ current_folder = output_dir.split('/')[-1]
222
+ parent_folder = os.path.dirname(output_dir)
223
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
224
+ if current_folder.startswith('checkpoint-'):
225
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
226
+ os.makedirs(mm_projector_folder, exist_ok=True)
227
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
228
+ else:
229
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
230
+ return
231
+
232
+ if trainer.deepspeed:
233
+ torch.cuda.synchronize()
234
+ trainer.save_model(output_dir)
235
+ return
236
+
237
+ state_dict = trainer.model.state_dict()
238
+ if trainer.args.should_save:
239
+ cpu_state_dict = {
240
+ key: value.cpu()
241
+ for key, value in state_dict.items()
242
+ }
243
+ del state_dict
244
+ trainer._save(output_dir, state_dict=cpu_state_dict)
245
+
246
+
247
+ def smart_tokenizer_and_embedding_resize(
248
+ special_tokens_dict: Dict,
249
+ tokenizer: transformers.PreTrainedTokenizer,
250
+ model: transformers.PreTrainedModel,
251
+ ):
252
+ """Resize the tokenizer and the model embeddings after adding new special tokens (e.g., <video>).
253
+
254
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
255
+ """
256
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
257
+ model.resize_token_embeddings(len(tokenizer))
258
+
259
+ if num_new_tokens > 0:
260
+ input_embeddings = model.get_input_embeddings().weight.data
261
+ output_embeddings = model.get_output_embeddings().weight.data
262
+
263
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
264
+ dim=0, keepdim=True)
265
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
266
+ dim=0, keepdim=True)
267
+
268
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
269
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
270
+
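A minimal sketch of the mean-initialisation behaviour above (illustrative, not part of the commit; uses a generic HF checkpoint and token name for demonstration):

from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
smart_tokenizer_and_embedding_resize({"pad_token": "[PAD]"}, tok, model)
# the new [PAD] rows of the input and output embeddings equal the mean of the pre-existing rows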
271
+
272
+ def _tokenize_fn(strings: Sequence[str],
273
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
274
+ """
275
+ Tokenizes a list of input strings and returns tokenized results along with sequence lengths.
276
+ """
277
+ tokenized_list = [
278
+ tokenizer(
279
+ text,
280
+ return_tensors="pt",
281
+ padding="longest",
282
+ max_length=tokenizer.model_max_length,
283
+ truncation=True,
284
+ ) for text in strings
285
+ ]
286
+ input_ids = labels = [
287
+ tokenized.input_ids[0] for tokenized in tokenized_list
288
+ ]
289
+ input_ids_lens = labels_lens = [
290
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
291
+ for tokenized in tokenized_list
292
+ ]
293
+ return dict(
294
+ input_ids=input_ids,
295
+ labels=labels,
296
+ input_ids_lens=input_ids_lens,
297
+ labels_lens=labels_lens,
298
+ )
299
+
300
+
301
+ def _mask_targets(target, tokenized_lens, speakers):
302
+ # cur_idx = 0
303
+ cur_idx = tokenized_lens[0]
304
+ tokenized_lens = tokenized_lens[1:]
305
+ target[:cur_idx] = IGNORE_INDEX
306
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
307
+ if speaker == "human":
308
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
309
+ cur_idx += tokenized_len
310
+
311
+
312
+ def _add_speaker_and_signal(header, source, get_conversation=True):
313
+ """Add speaker and start/end signal on each round."""
314
+ BEGIN_SIGNAL = "### "
315
+ END_SIGNAL = "\n"
316
+ conversation = header
317
+ for sentence in source:
318
+ from_str = sentence["from"]
319
+ if from_str.lower() == "human":
320
+ from_str = conversation_lib.default_conversation.roles[0]
321
+ elif from_str.lower() == "gpt":
322
+ from_str = conversation_lib.default_conversation.roles[1]
323
+ else:
324
+ from_str = 'unknown'
325
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
326
+ sentence["value"] + END_SIGNAL)
327
+ if get_conversation:
328
+ conversation += sentence["value"]
329
+ conversation += BEGIN_SIGNAL
330
+ return conversation
331
+
332
+
333
+ def preprocess_multimodal(
334
+ sources: Sequence[str],
335
+ data_args: DataArguments
336
+ ) -> Dict:
337
+ is_multimodal = data_args.is_multimodal
338
+ if not is_multimodal:
339
+ return sources
340
+
341
+ for source in sources:
342
+ for sentence in source:
343
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
344
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
345
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
346
+ sentence['value'] = sentence['value'].strip()
347
+ if "mmtag" in conversation_lib.default_conversation.version:
348
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
349
+ replace_token = DEFAULT_IMAGE_TOKEN
350
+ if data_args.mm_use_im_start_end:
351
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
352
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
353
+ return sources
354
+
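A minimal sketch of what the rewrite above does to a conversation turn (illustrative, not part of the commit; assumes the default conversation template is not an "mmtag" variant):

from types import SimpleNamespace

src = [[{"from": "human", "value": "Describe the image. <image>"},
        {"from": "gpt", "value": "No acute findings."}]]
args = SimpleNamespace(is_multimodal=True, mm_use_im_start_end=False)
print(preprocess_multimodal(src, args)[0][0]["value"])  # -> "<image>\nDescribe the image."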
355
+
356
+ def preprocess_llama_2(
357
+ sources,
358
+ tokenizer: transformers.PreTrainedTokenizer,
359
+ has_image: bool = False
360
+ ) -> Dict:
361
+ conv = conversation_lib.default_conversation.copy()
362
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
363
+
364
+ # Apply prompt templates
365
+ conversations = []
366
+ for i, source in enumerate(sources):
367
+ if roles[source[0]["from"]] != conv.roles[0]:
368
+ # Skip the first one if it is not from human
369
+ source = source[1:]
370
+
371
+ conv.messages = []
372
+ for j, sentence in enumerate(source):
373
+ role = roles[sentence["from"]]
374
+ assert role == conv.roles[j % 2], f"{i}"
375
+ conv.append_message(role, sentence["value"])
376
+ conversations.append(conv.get_prompt())
377
+
378
+ # Tokenize conversations
379
+
380
+ if has_image:
381
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
382
+ else:
383
+ input_ids = tokenizer(
384
+ conversations,
385
+ return_tensors="pt",
386
+ padding="longest",
387
+ max_length=tokenizer.model_max_length,
388
+ truncation=True,
389
+ ).input_ids
390
+
391
+ targets = input_ids.clone()
392
+
393
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
394
+
395
+ # Mask targets
396
+ sep = "[/INST] "
397
+ for conversation, target in zip(conversations, targets):
398
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
399
+
400
+ rounds = conversation.split(conv.sep2)
401
+ cur_len = 1
402
+ target[:cur_len] = IGNORE_INDEX
403
+ for i, rou in enumerate(rounds):
404
+ if rou == "":
405
+ break
406
+
407
+ parts = rou.split(sep)
408
+ if len(parts) != 2:
409
+ break
410
+ parts[0] += sep
411
+
412
+ if has_image:
413
+ round_len = len(tokenizer_image_token(rou, tokenizer))
414
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
415
+ else:
416
+ round_len = len(tokenizer(rou).input_ids)
417
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
418
+
419
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
420
+
421
+ cur_len += round_len
422
+ target[cur_len:] = IGNORE_INDEX
423
+
424
+ if cur_len < tokenizer.model_max_length:
425
+ if cur_len != total_len:
426
+ target[:] = IGNORE_INDEX
427
+ print(
428
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
429
+ f" (ignored)"
430
+ )
431
+
432
+ return dict(
433
+ input_ids=input_ids,
434
+ labels=targets,
435
+ )
436
+
437
+ # llama_3
438
+ def preprocess_llama_3(
439
+ sources,
440
+ tokenizer: transformers.PreTrainedTokenizer,
441
+ has_image: bool = False
442
+ ) -> Dict:
443
+
444
+ special_token = "<|finetune_right_pad_id|>"
445
+
446
+ if tokenizer.pad_token_id is None:
447
+
448
+ pad_token_id = tokenizer.convert_tokens_to_ids(special_token)
449
+ if pad_token_id is None:
450
+ raise ValueError(f"Cannot find ID for {special_token}. Please check the tokenizer.")
451
+
452
+ tokenizer.pad_token_id = pad_token_id
453
+
454
+ conv = conversation_lib.default_conversation.copy()
455
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
456
+
457
+ # Apply prompt templates
458
+ conversations = []
459
+ for i, source in enumerate(sources):
460
+ if roles[source[0]["from"]] != conv.roles[0]:
461
+ # Skip the first one if it is not from human
462
+ source = source[1:]
463
+
464
+ conv.messages = []
465
+ for j, sentence in enumerate(source):
466
+ role = roles[sentence["from"]]
467
+ assert role == conv.roles[j % 2], f"{i}"
468
+ conv.append_message(role, sentence["value"])
469
+ conversations.append(conv.get_prompt())
470
+
471
+ if has_image:
472
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
473
+ else:
474
+ input_ids = tokenizer(
475
+ conversations,
476
+ return_tensors="pt",
477
+ padding="longest",
478
+ max_length=tokenizer.model_max_length,
479
+ truncation=True,
480
+ ).input_ids
481
+
482
+ targets = input_ids.clone()
483
+
484
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3
485
+
486
+ sep_round = "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>"
487
+ sep_user = "<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
488
+ for conversation, target in zip(conversations, targets):
489
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
490
+ rounds = conversation.split(sep_round)
491
+ cur_len = 1
492
+ target[:cur_len] = IGNORE_INDEX
493
+ for i, rou in enumerate(rounds):
494
+ if rou == "":
495
+ break
496
+
497
+ parts = rou.split(sep_user)
498
+ if len(parts) != 2:
499
+ break
500
+ parts[0] += sep_user
501
+
502
+ if has_image:
503
+ round_len = len(tokenizer_image_token(rou, tokenizer)) - 1
504
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
505
+ else:
506
+ round_len = len(tokenizer(rou).input_ids) - 1
507
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
508
+
509
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
510
+
511
+ cur_len += round_len
512
+ target[cur_len:] = IGNORE_INDEX
513
+
514
+ if cur_len < tokenizer.model_max_length:
515
+ if cur_len != total_len:
516
+ target[:] = IGNORE_INDEX
517
+ print(
518
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
519
+ f" (ignored)"
520
+ )
521
+
522
+ return dict(
523
+ input_ids=input_ids,
524
+ labels=targets,
525
+ )
526
+
527
+ def preprocess_libra(
528
+ sources,
529
+ tokenizer: transformers.PreTrainedTokenizer,
530
+ has_image: bool = False
531
+ ) -> Dict:
532
+ conv = conversation_lib.default_conversation.copy()
533
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
534
+
535
+ # Apply prompt templates
536
+ conversations = []
537
+ for i, source in enumerate(sources):
538
+ if roles[source[0]["from"]] != conv.roles[0]:
539
+ # Skip the first one if it is not from human
540
+ source = source[1:]
541
+
542
+ conv.messages = []
543
+ for j, sentence in enumerate(source):
544
+ role = roles[sentence["from"]]
545
+ assert role == conv.roles[j % 2], f"{i}"
546
+ conv.append_message(role, sentence["value"])
547
+ conversations.append(conv.get_prompt())
548
+
549
+ # Tokenize conversations
550
+ if has_image:
551
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
552
+ else:
553
+ input_ids = tokenizer(
554
+ conversations,
555
+ return_tensors="pt",
556
+ padding="longest",
557
+ max_length=tokenizer.model_max_length,
558
+ truncation=True,
559
+ ).input_ids
560
+
561
+ targets = input_ids.clone()
562
+
563
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
564
+
565
+ # Mask targets
566
+ sep = conv.sep + conv.roles[1] + ": "
567
+ for conversation, target in zip(conversations, targets):
568
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
569
+
570
+
571
+ rounds = conversation.split(conv.sep2)
572
+ cur_len = 1
573
+ target[:cur_len] = IGNORE_INDEX
574
+ for i, rou in enumerate(rounds):
575
+ if rou == "":
576
+ break
577
+
578
+ parts = rou.split(sep)
579
+ if len(parts) != 2:
580
+ break
581
+ parts[0] += sep
582
+
583
+ if has_image:
584
+ round_len = len(tokenizer_image_token(rou, tokenizer))
585
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
586
+ else:
587
+ round_len = len(tokenizer(rou).input_ids)
588
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
589
+
590
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
591
+ round_len -= 1
592
+ instruction_len -= 1
593
+
594
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
595
+
596
+ cur_len += round_len
597
+ target[cur_len:] = IGNORE_INDEX
598
+
599
+ if cur_len < tokenizer.model_max_length:
600
+ if cur_len != total_len:
601
+ target[:] = IGNORE_INDEX
602
+ print(
603
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
604
+ f" (ignored)"
605
+ )
606
+
607
+ return dict(
608
+ input_ids=input_ids,
609
+ labels=targets,
610
+ )
611
+
612
+
613
+ def preprocess_mpt(
614
+ sources,
615
+ tokenizer: transformers.PreTrainedTokenizer,
616
+ has_image: bool = False
617
+ ) -> Dict:
618
+ conv = conversation_lib.default_conversation.copy()
619
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
620
+
621
+ # Apply prompt templates
622
+ conversations = []
623
+ for i, source in enumerate(sources):
624
+ if roles[source[0]["from"]] != conv.roles[0]:
625
+ # Skip the first one if it is not from human
626
+ source = source[1:]
627
+
628
+ conv.messages = []
629
+ for j, sentence in enumerate(source):
630
+ role = roles[sentence["from"]]
631
+ assert role == conv.roles[j % 2], f"{i}"
632
+ conv.append_message(role, sentence["value"])
633
+ conversations.append(conv.get_prompt())
634
+
635
+ # Tokenize conversations
636
+
637
+ if has_image:
638
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
639
+ else:
640
+ input_ids = tokenizer(
641
+ conversations,
642
+ return_tensors="pt",
643
+ padding="longest",
644
+ max_length=tokenizer.model_max_length,
645
+ truncation=True,
646
+ ).input_ids
647
+
648
+ targets = input_ids.clone()
649
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
650
+
651
+ # Mask targets
652
+ sep = conv.sep + conv.roles[1]
653
+ for conversation, target in zip(conversations, targets):
654
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
655
+
656
+ rounds = conversation.split(conv.sep)
657
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
658
+ for conv_idx in range(3, len(rounds), 2):
659
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
660
+ cur_len = 0
661
+ target[:cur_len] = IGNORE_INDEX
662
+ for i, rou in enumerate(re_rounds):
663
+ if rou == "":
664
+ break
665
+
666
+ parts = rou.split(sep)
667
+ if len(parts) != 2:
668
+ break
669
+ parts[0] += sep
670
+
671
+ if has_image:
672
+ round_len = len(tokenizer_image_token(rou, tokenizer))
673
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
674
+ else:
675
+ round_len = len(tokenizer(rou).input_ids)
676
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
677
+
678
+ if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
679
+ round_len += 1
680
+ instruction_len += 1
681
+
682
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
683
+
684
+ cur_len += round_len
685
+ target[cur_len:] = IGNORE_INDEX
686
+
687
+ if cur_len < tokenizer.model_max_length:
688
+ if cur_len != total_len:
689
+ target[:] = IGNORE_INDEX
690
+ print(
691
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
692
+ f" (ignored)"
693
+ )
694
+
695
+ return dict(
696
+ input_ids=input_ids,
697
+ labels=targets,
698
+ )
699
+
700
+
701
+ def preprocess_plain(
702
+ sources: Sequence[str],
703
+ tokenizer: transformers.PreTrainedTokenizer,
704
+ ) -> Dict:
705
+ # add end signal and concatenate together
706
+ conversations = []
707
+ for source in sources:
708
+ assert len(source) == 2
709
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
710
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
711
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
712
+ conversations.append(conversation)
713
+ # tokenize conversations
714
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
715
+ targets = copy.deepcopy(input_ids)
716
+ for target, source in zip(targets, sources):
717
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
718
+ target[:tokenized_len] = IGNORE_INDEX
719
+
720
+ return dict(input_ids=input_ids, labels=targets)
721
+
722
+
723
+ def load_images(image_file):
724
+ """
725
+ Load an image from a local file, a URL, or a DICOM file.
726
+
727
+ Args:
728
+ image_file (str): The path or URL of the image file to load.
729
+
730
+ Returns:
731
+ PIL.Image.Image: The loaded image in RGB format.
732
+
733
+ Raises:
734
+ ValueError: If the DICOM file does not contain image data.
735
+ TypeError: If the input is neither a valid file path nor a URL.
736
+ """
737
+ if isinstance(image_file, str):
738
+ # Case 1: Load from URL
739
+ if image_file.startswith(('http://', 'https://')):
740
+ try:
741
+ response = requests.get(image_file)
742
+ response.raise_for_status()
743
+ image = Image.open(BytesIO(response.content)).convert('RGB')
744
+ except Exception as e:
745
+ raise ValueError(f"Error loading image from URL: {image_file}\n{e}")
746
+
747
+ # Case 2: Load from DICOM file
748
+ elif image_file.lower().endswith('.dcm'):
749
+ try:
750
+ dicom = pydicom.dcmread(image_file)
751
+ if 'PixelData' in dicom:
752
+ data = apply_voi_lut(dicom.pixel_array, dicom)
753
+
754
+ # Handle MONOCHROME1 images
755
+ if dicom.PhotometricInterpretation == "MONOCHROME1":
756
+ data = np.max(data) - data
757
+
758
+ # Normalize the image data
759
+ data = data - np.min(data)
760
+ data = data / np.max(data)
761
+ data = (data * 255).astype(np.uint8)
762
+
763
+ # Convert to 3-channel RGB if necessary
764
+ if data.ndim == 2:
765
+ data = np.stack([data] * 3, axis=-1)
766
+
767
+ image = Image.fromarray(data).convert('RGB')
768
+ else:
769
+ raise ValueError("DICOM file does not contain image data")
770
+ except Exception as e:
771
+ raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")
772
+
773
+ # Case 3: Load standard image files (e.g., PNG, JPG)
774
+ else:
775
+ try:
776
+ image = Image.open(image_file).convert('RGB')
777
+ except Exception as e:
778
+ raise ValueError(f"Error loading standard image file: {image_file}\n{e}")
779
+
780
+ else:
781
+ raise TypeError("image_file must be a string representing a file path or URL")
782
+
783
+ return image
784
+
785
+
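A minimal usage sketch of the loader above (illustrative, not part of the commit; the paths are placeholders):

img = load_images("path/to/study.dcm")            # DICOM: VOI LUT, MONOCHROME1 inversion, 8-bit RGB
web = load_images("https://example.com/cxr.jpg")  # URL: downloaded with requests, converted to RGB
print(img.mode)                                   # "RGB"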
786
+ def preprocess(
787
+ sources: Sequence[str],
788
+ tokenizer: transformers.PreTrainedTokenizer,
789
+ has_image: bool = False
790
+ ) -> Dict:
791
+ """
792
+ Given a list of sources, each is a conversation list. This transform:
793
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
794
+ 2. Concatenate conversations together;
795
+ 3. Tokenize the concatenated conversation;
796
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
797
+ """
798
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
799
+ return preprocess_plain(sources, tokenizer)
800
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
801
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
802
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3:
803
+ return preprocess_llama_3(sources, tokenizer, has_image=has_image)
804
+ if conversation_lib.default_conversation.version.startswith("v1"):
805
+ return preprocess_libra(sources, tokenizer, has_image=has_image)
806
+ if conversation_lib.default_conversation.version == "mpt":
807
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
808
+
809
+ conversations = []
810
+ for source in sources:
811
+ header = f"{conversation_lib.default_conversation.system}\n\n"
812
+ conversation = _add_speaker_and_signal(header, source)
813
+ conversations.append(conversation)
814
+ # tokenize conversations
815
+ def get_tokenize_len(prompts):
816
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
817
+
818
+ if has_image:
819
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
820
+ else:
821
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
822
+ input_ids = conversations_tokenized["input_ids"]
823
+
824
+ targets = copy.deepcopy(input_ids)
825
+ for target, source in zip(targets, sources):
826
+ if has_image:
827
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
828
+ else:
829
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
830
+ speakers = [sentence["from"] for sentence in source]
831
+ _mask_targets(target, tokenized_lens, speakers)
832
+
833
+ return dict(input_ids=input_ids, labels=targets)
834
+
835
+
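From the way each source is indexed above (sentence["from"], s["value"]), every conversation is a list of turn dicts; a hedged sketch of a direct call (the role names and text are assumptions, and the actual routing may go through preprocess_libra or another branch depending on the active conversation template):

example_sources = [[
    {"from": "human", "value": "<image>\nDescribe the findings in this chest X-ray."},  # assumed role name
    {"from": "gpt", "value": "No acute cardiopulmonary abnormality."},                  # assumed role name
]]
data = preprocess(example_sources, tokenizer, has_image=True)
# data["input_ids"] and data["labels"] are aligned; human-side tokens are masked with IGNORE_INDEX.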
836
+ def create_compute_metrics(tokenizer, num_patches: int, sep2: str):
837
+ """
838
+ Creates a function to compute evaluation metrics (e.g., BLEU, ROUGE-L, Temporal-F1) for the model,
839
+ based on the given tokenizer and 'num_patches' parameter.
840
+
841
+ Args:
842
+ tokenizer: The tokenizer used for encoding/decoding text.
843
+ num_patches (int): The number of patches to be adjusted in the labels.
844
+ sep2 (str): A separator token used to identify a special token ID.
845
+
846
+ Returns:
847
+ A callable function 'compute_metrics(eval_pred)' that computes evaluation metrics.
848
+ """
849
+ # Pre-fetch special token IDs to avoid repeated calls
850
+ bos_token_id = tokenizer.convert_tokens_to_ids(sep2)
851
+ newline_token_id = tokenizer.convert_tokens_to_ids('<0x0A>')
852
+ # 0 is commonly used as the <pad> token ID
853
+ special_token_ids = [bos_token_id, newline_token_id, 0]
854
+
855
+ # Pre-load evaluation metrics (adjust if needed for your scenario)
856
+ bleu_metric = evaluate.load("bleu")
857
+ rouge_metric = evaluate.load("rouge")
858
+
859
+ def compute_metrics(eval_pred: EvalPrediction) -> dict:
860
+ """
861
+ Compute the evaluation metrics actually returned below: BLEU-4, ROUGE-L, and Temporal-F1.
862
+
863
+ Args:
864
+ eval_pred (EvalPrediction): Contains model predictions and true labels.
865
+
866
+ Returns:
867
+ dict: Dictionary containing evaluation metric scores.
868
+ """
869
+ logits, labels = eval_pred.predictions, eval_pred.label_ids
870
+ predicted_ids = np.argmax(logits, axis=-1)
871
+
872
+ # Store processed predicted token IDs
873
+ processed_predicted_ids = []
874
+
875
+ for label, predicted in zip(labels, predicted_ids):
876
+ # (1) Find ignore_count: the position of the first non-IGNORE_INDEX token in the label
877
+ ignore_count = next(
878
+ (i for i, token in enumerate(label) if token != IGNORE_INDEX),
879
+ len(label) # If all are -100, use the length of the label
880
+ )
881
+
882
+ # (2) Calculate the truncation start index
883
+ # This depends on 'num_patches' and the ignored tokens.
884
+ start_index = ignore_count + num_patches - 2
885
+
886
+ # If start_index exceeds the predicted sequence length, append an empty list
887
+ if start_index >= len(predicted):
888
+ processed_predicted_ids.append([])
889
+ continue
890
+
891
+ # (3) Slice the prediction from 'start_index' onwards
892
+ temp_predicted = predicted[start_index:]
893
+
894
+ # (4) Find the earliest occurrence of any special token to truncate
895
+ matching_indices = []
896
+ for token_id in special_token_ids:
897
+ idx = np.where(temp_predicted == token_id)[0]
898
+ if idx.size > 0:
899
+ matching_indices.append(idx)
900
+
901
+ if matching_indices:
902
+ # Merge all matching indices and take the smallest
903
+ all_indices = np.concatenate(matching_indices)
904
+ first_match_index = np.min(all_indices)
905
+ # Truncate up to the first special token
906
+ temp_predicted = temp_predicted[:first_match_index]
907
+
908
+ # Append the processed sequence to the results
909
+ processed_predicted_ids.append(temp_predicted)
910
+
911
+ # Decode the processed prediction IDs
912
+ decoded_preds = tokenizer.batch_decode(
913
+ processed_predicted_ids,
914
+ skip_special_tokens=True
915
+ )
916
+
917
+ # Filter labels by removing IGNORE_INDEX tokens
918
+ filtered_labels = [
919
+ [token for token in label_group if token != IGNORE_INDEX]
920
+ for label_group in labels
921
+ ]
922
+
923
+ decoded_labels = tokenizer.batch_decode(
924
+ filtered_labels,
925
+ skip_special_tokens=True
926
+ )
927
+
928
+ references = [[lbl] for lbl in decoded_labels]
929
+
930
+ # Calculate BLEU score
931
+ bleu_score = bleu_metric.compute(
932
+ predictions=decoded_preds,
933
+ references=references,
934
+ max_order=4
935
+ )["bleu"]
936
+
937
+ # Calculate ROUGE-L score
938
+ rouge_score = rouge_metric.compute(
939
+ predictions=decoded_preds,
940
+ references=references
941
+ )["rougeL"]
942
+
943
+ # Calculate Temporal-F1 score
944
+ tem_f1_score = temporal_f1_score(
945
+ predictions=decoded_preds,
946
+ references=references
947
+ )["f1"]
948
+
949
+ # Clean up memory
950
+ del logits, labels, decoded_preds, decoded_labels, references
951
+ torch.cuda.empty_cache()
952
+
953
+ # Return metrics
954
+ return {
955
+ "BLEU4": bleu_score,
956
+ "ROUGE-L": rouge_score,
957
+ "TEM-F1": tem_f1_score
958
+ }
959
+
960
+ return compute_metrics
961
+
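As a pointer, create_compute_metrics is wired up near the end of train() below; a condensed sketch mirroring that call (the trailing comments are interpretive):

compute_metrics_func = create_compute_metrics(
    tokenizer,
    vision_tower.num_patches,                    # visual patches folded into the start_index offset above
    conversation_lib.default_conversation.sep2,  # separator token used to truncate generated text
)
# compute_metrics_func is then passed to the trainer as compute_metrics=compute_metrics_func.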
962
+ def check_trainable_parameters(model: torch.nn.Module) -> None:
963
+ """
964
+ Print the names, shapes, and data types of all trainable parameters in the model.
965
+
966
+ Args:
967
+ model (torch.nn.Module): The model to inspect.
968
+ """
969
+ total_params = sum(
970
+ p.numel() for p in model.parameters() if p.requires_grad
971
+ )
972
+
973
+ print(f"Total number of trainable parameters: {total_params:,d}\n")
974
+
975
+ # (Optional) Print the model structure for reference
976
+ print("Overall model structure:")
977
+ print(model)
978
+ print("\nTrainable parameters:")
979
+
980
+ # Print details of each trainable parameter
981
+ for name, param in model.named_parameters():
982
+ if param.requires_grad:
983
+ param_info = (
984
+ f"Shape: {list(param.shape)}, "
985
+ f"Dtype: {param.dtype}"
986
+ )
987
+ print(f" - {name} -> {param_info}")
988
+
989
+ class LazySupervisedDataset(Dataset):
990
+ """Dataset for supervised fine-tuning."""
991
+
992
+ def __init__(self, data_path: str,
993
+ tokenizer: transformers.PreTrainedTokenizer,
994
+ data_args: DataArguments,
995
+ sample_rate=1.0):
996
+ super(LazySupervisedDataset, self).__init__()
997
+ list_data_dict = json.load(open(data_path, "r"))
998
+
999
+ # Apply sampling if sample_rate < 1.0
1000
+ if 0 < sample_rate < 1.0:
1001
+ random.seed(27) # Fixed seed for consistent behavior across different runs
1002
+ sampled_size = int(len(list_data_dict) * sample_rate)
1003
+ list_data_dict = random.sample(list_data_dict, sampled_size)
1004
+
1005
+ rank0_print("Formatting inputs...Skip in lazy mode")
1006
+ self.tokenizer = tokenizer
1007
+ self.list_data_dict = list_data_dict
1008
+ self.data_args = data_args
1009
+
1010
+ def __len__(self):
1011
+ return len(self.list_data_dict)
1012
+
1013
+ @property
1014
+ def lengths(self):
1015
+ length_list = []
1016
+ for sample in self.list_data_dict:
1017
+ img_tokens = 128 if 'image' in sample else 0
1018
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
1019
+ return length_list
1020
+
1021
+ @property
1022
+ def modality_lengths(self):
1023
+ length_list = []
1024
+ for sample in self.list_data_dict:
1025
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
1026
+ cur_len = cur_len if 'image' in sample else -cur_len
1027
+ length_list.append(cur_len)
1028
+ return length_list
1029
+
1030
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
1031
+ sources = self.list_data_dict[i]
1032
+ if isinstance(i, int):
1033
+ sources = [sources]
1034
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
1035
+ if 'image' in sources[0]:
1036
+ image_file = self.list_data_dict[i]['image']
1037
+ image_folder = self.data_args.image_folder
1038
+ processor = self.data_args.image_processor
1039
+
1040
+ if isinstance(image_file, str):
1041
+ image = []
1042
+ image_path = os.path.join(image_folder, image_file)
1043
+ img = load_images(image_path)
1044
+ image.append(img)
1045
+ # set dummy prior image
1046
+ image.append(img)
1047
+
1048
+ elif isinstance(image_file, (list, tuple)):
1049
+ image = []
1050
+ image_paths = [os.path.join(image_folder, file_name) for file_name in image_file]
1051
+ for path in image_paths:
1052
+ img = load_images(path)
1053
+ image.append(img)
1054
+ # set dummy prior image
1055
+ if len(image) == 1:
1056
+ print("Contains only current image. Adding a dummy prior image.")
1057
+ image.append(image[0])
1058
+
1059
+ else:
1060
+ raise TypeError("image_file must be a string or a list/tuple of strings")
1061
+
1062
+ if self.data_args.image_aspect_ratio == 'pad':
1063
+ def expand2square(pil_img, background_color=(0, 0, 0)):
1064
+ width, height = pil_img.size
1065
+ if width == height:
1066
+ return pil_img
1067
+ elif width > height:
1068
+ result = Image.new(pil_img.mode, (width, width), background_color)
1069
+ result.paste(pil_img, (0, (width - height) // 2))
1070
+ return result
1071
+ else:
1072
+ result = Image.new(pil_img.mode, (height, height), background_color)
1073
+ result.paste(pil_img, ((height - width) // 2, 0))
1074
+ return result
1075
+
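# Worked example of expand2square (illustrative sizes): a 640x512 image (W x H) is padded to a
# 640x640 canvas and pasted at y-offset (640 - 512) // 2 = 64; a 512x640 image is padded to a
# 640x640 canvas and pasted at x-offset (640 - 512) // 2 = 64, keeping the original content centered.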
1076
+ processed_images = []
1077
+ for img_data in image:
1078
+ pad_image = expand2square(img_data, (0, 0, 0))
1079
+ image_temp = processor.preprocess(pad_image, return_tensors='pt')['pixel_values'][0]
1080
+ processed_images.append(image_temp)
1081
+ image = processed_images
1082
+
1083
+ else:
1084
+ processed_images = []
1085
+ for img_data in image:
1086
+ image_temp = processor.preprocess(img_data, return_tensors='pt')['pixel_values'][0]
1087
+ processed_images.append(image_temp)
1088
+ image = processed_images
1089
+
1090
+ sources = preprocess_multimodal(
1091
+ copy.deepcopy([e["conversations"] for e in sources]),
1092
+ self.data_args)
1093
+ else:
1094
+ sources = copy.deepcopy([e["conversations"] for e in sources])
1095
+
1096
+ data_dict = preprocess(
1097
+ sources,
1098
+ self.tokenizer,
1099
+ has_image=('image' in self.list_data_dict[i]))
1100
+ if isinstance(i, int):
1101
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
1102
+ labels=data_dict["labels"][0])
1103
+
1104
+ # image exist in the data
1105
+ if 'image' in self.list_data_dict[i]:
1106
+ data_dict['image'] = image
1107
+ elif self.data_args.is_multimodal:
1108
+ # image does not exist in the data, but the model is multimodal
1109
+ crop_size = self.data_args.image_processor.crop_size
1110
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
1111
+ return data_dict
1112
+
1113
+
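Putting the indexing above together, each record in the JSON file referenced by data_path carries an optional 'image' entry (a single filename or a [current, prior] pair) plus a 'conversations' list; a hedged sketch of one record (all filenames and text are placeholders):

example_record = {
    "image": ["current_view.jpg", "prior_view.jpg"],   # or a single filename string
    "conversations": [
        {"from": "human", "value": "<image>\nProvide a report, comparing with the prior study."},
        {"from": "gpt", "value": "No significant interval change."},
    ],
}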
1114
+ @dataclass
1115
+ class DataCollatorForSupervisedDataset(object):
1116
+ """Collate examples for supervised fine-tuning."""
1117
+
1118
+ tokenizer: transformers.PreTrainedTokenizer
1119
+
1120
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
1121
+ input_ids, labels = tuple([instance[key] for instance in instances]
1122
+ for key in ("input_ids", "labels"))
1123
+ input_ids = torch.nn.utils.rnn.pad_sequence(
1124
+ input_ids,
1125
+ batch_first=True,
1126
+ padding_value=self.tokenizer.pad_token_id)
1127
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
1128
+ batch_first=True,
1129
+ padding_value=IGNORE_INDEX)
1130
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
1131
+ labels = labels[:, :self.tokenizer.model_max_length]
1132
+ batch = dict(
1133
+ input_ids=input_ids,
1134
+ labels=labels,
1135
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
1136
+ )
1137
+
1138
+ if 'image' in instances[0]:
1139
+
1140
+ if not all(len(instance['image']) == 2 for instance in instances):
1141
+ raise ValueError("Each instance['image'] must contain exactly two type images.")
1142
+
1143
+ cur_images = [instance['image'][0] for instance in instances]
1144
+ prior_images = [instance['image'][1] for instance in instances]
1145
+
1146
+
1147
+ if all(x is not None and x.shape == cur_images[0].shape for x in cur_images) and \
1148
+ all(x is not None and x.shape == prior_images[0].shape for x in prior_images):
1149
+
1150
+ batch['images'] = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])
1151
+ else:
1152
+ print("Warning: Image shapes are inconsistent. Using lists for images.")
1153
+ batch['images'] = [cur_images, prior_images]
1154
+
1155
+ return batch
1156
+
1157
+
1158
+
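When all image tensors share a shape, the collator stacks current and prior images along a leading axis; a sketch of the resulting batch layout (the batch size B and the 336x336 crop are illustrative assumptions):

collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
batch = collator(instances)   # instances: list of dicts produced by LazySupervisedDataset
# batch["input_ids"]: (B, L) padded with pad_token_id and truncated to tokenizer.model_max_length
# batch["labels"]:    (B, L) padded with IGNORE_INDEX
# batch["images"]:    (2, B, 3, 336, 336) -> index 0 = current images, index 1 = prior images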
1159
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
1160
+ data_args) -> Dict:
1161
+ """Make dataset and collator for supervised fine-tuning."""
1162
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
1163
+ data_path=data_args.data_path,
1164
+ data_args=data_args)
1165
+
1166
+ eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
1167
+ data_path=data_args.validation_data_path,
1168
+ data_args=data_args,
1169
+ sample_rate=1.0)
1170
+
1171
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
1172
+ return dict(train_dataset=train_dataset,
1173
+ eval_dataset=eval_dataset,
1174
+ data_collator=data_collator)
1175
+
1176
+ def train():
1177
+ global local_rank
1178
+
1179
+ parser = transformers.HfArgumentParser(
1180
+ (ModelArguments, DataArguments, TrainingArguments))
1181
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
1182
+ local_rank = training_args.local_rank
1183
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1184
+
1185
+ bnb_model_from_pretrained_args = {}
1186
+ if training_args.bits in [4, 8]:
1187
+ from transformers import BitsAndBytesConfig
1188
+ bnb_model_from_pretrained_args.update(dict(
1189
+ device_map={"": training_args.device},
1190
+ load_in_4bit=training_args.bits == 4,
1191
+ load_in_8bit=training_args.bits == 8,
1192
+ quantization_config=BitsAndBytesConfig(
1193
+ load_in_4bit=training_args.bits == 4,
1194
+ load_in_8bit=training_args.bits == 8,
1195
+ llm_int8_skip_modules=["mm_projector"],
1196
+ llm_int8_threshold=6.0,
1197
+ llm_int8_has_fp16_weight=False,
1198
+ bnb_4bit_compute_dtype=compute_dtype,
1199
+ bnb_4bit_use_double_quant=training_args.double_quant,
1200
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
1201
+ )
1202
+ ))
1203
+
1204
+ if model_args.vision_tower is not None:
1205
+ model = LibraLlamaForCausalLM.from_pretrained(
1206
+ model_args.model_name_or_path,
1207
+ cache_dir=training_args.cache_dir,
1208
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1209
+ **bnb_model_from_pretrained_args
1210
+ )
1211
+ else:
1212
+ model = transformers.LlamaForCausalLM.from_pretrained(
1213
+ model_args.model_name_or_path,
1214
+ cache_dir=training_args.cache_dir,
1215
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1216
+ **bnb_model_from_pretrained_args
1217
+ )
1218
+ model.config.use_cache = False
1219
+
1220
+ if model_args.freeze_backbone:
1221
+ model.model.requires_grad_(False)
1222
+
1223
+ if training_args.bits in [4, 8]:
1224
+ from peft import prepare_model_for_kbit_training
1225
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1226
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
1227
+
1228
+ if training_args.gradient_checkpointing:
1229
+ if hasattr(model, "enable_input_require_grads"):
1230
+ model.enable_input_require_grads()
1231
+ else:
1232
+ def make_inputs_require_grad(module, input, output):
1233
+ output.requires_grad_(True)
1234
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
1235
+
1236
+ if training_args.lora_enable:
1237
+ from peft import LoraConfig, get_peft_model
1238
+ lora_config = LoraConfig(
1239
+ r=training_args.lora_r,
1240
+ lora_alpha=training_args.lora_alpha,
1241
+ target_modules=find_all_linear_names(model),
1242
+ lora_dropout=training_args.lora_dropout,
1243
+ bias=training_args.lora_bias,
1244
+ task_type="CAUSAL_LM",
1245
+ )
1246
+ if training_args.bits == 16:
1247
+ if training_args.bf16:
1248
+ model.to(torch.bfloat16)
1249
+ if training_args.fp16:
1250
+ model.to(torch.float16)
1251
+ rank0_print("Adding LoRA adapters...")
1252
+ model = get_peft_model(model, lora_config)
1253
+
1254
+ if 'mpt' in model_args.model_name_or_path:
1255
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1256
+ model_args.model_name_or_path,
1257
+ cache_dir=training_args.cache_dir,
1258
+ model_max_length=training_args.model_max_length,
1259
+ padding_side="right"
1260
+ )
1261
+ else:
1262
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1263
+ model_args.model_name_or_path,
1264
+ cache_dir=training_args.cache_dir,
1265
+ model_max_length=training_args.model_max_length,
1266
+ padding_side="right",
1267
+ use_fast=False,
1268
+ )
1269
+
1270
+ if model_args.version == "v0":
1271
+ if tokenizer.pad_token is None:
1272
+ smart_tokenizer_and_embedding_resize(
1273
+ special_tokens_dict=dict(pad_token="[PAD]"),
1274
+ tokenizer=tokenizer,
1275
+ model=model,
1276
+ )
1277
+
1278
+ elif model_args.version == "v0.5":
1279
+ tokenizer.pad_token = tokenizer.unk_token
1280
+ else:
1281
+ tokenizer.pad_token = tokenizer.unk_token
1282
+ if model_args.version in conversation_lib.conv_templates:
1283
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
1284
+ else:
1285
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
1286
+
1287
+ if model_args.vision_tower is not None:
1288
+ model.get_model().initialize_vision_modules(
1289
+ model_args=model_args,
1290
+ fsdp=training_args.fsdp
1291
+ )
1292
+
1293
+ vision_tower = model.get_vision_tower()
1294
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
1295
+
1296
+ data_args.image_processor = vision_tower.image_processor
1297
+ data_args.is_multimodal = True
1298
+
1299
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
1300
+ model.config.tokenizer_padding_side = tokenizer.padding_side
1301
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
1302
+
1303
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
1304
+ if model_args.tune_mm_mlp_adapter:
1305
+ model.requires_grad_(False)
1306
+ for p in model.get_model().mm_projector.parameters():
1307
+ p.requires_grad = True
1308
+
1309
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
1310
+ if training_args.freeze_mm_mlp_adapter:
1311
+ for p in model.get_model().mm_projector.parameters():
1312
+ p.requires_grad = False
1313
+
1314
+ if training_args.bits in [4, 8]:
1315
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
1316
+
1317
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
1318
+ model.config.mm_projector_lr = training_args.mm_projector_lr
1319
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
1320
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
1321
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
1322
+
1323
+ if training_args.bits in [4, 8]:
1324
+ from peft.tuners.lora import LoraLayer
1325
+ for name, module in model.named_modules():
1326
+ if isinstance(module, LoraLayer):
1327
+ if training_args.bf16:
1328
+ module = module.to(torch.bfloat16)
1329
+ if 'norm' in name:
1330
+ module = module.to(torch.float32)
1331
+ if 'lm_head' in name or 'embed_tokens' in name:
1332
+ if hasattr(module, 'weight'):
1333
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1334
+ module = module.to(torch.bfloat16)
1335
+
1336
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
1337
+ data_args=data_args)
1338
+
1339
+
1340
+ # SaveCallback
1341
+ class SaveCallback(TrainerCallback):
1342
+
1343
+ def __init__(self):
1344
+ super().__init__()
1345
+ self.best_metric = None
1346
+
1347
+ def on_evaluate(self, args, state, control, metrics, **kwargs):
1348
+ """
1349
+ Custom logic for evaluating and saving the best model based on a chosen metric.
1350
+
1351
+ Saves the model and configuration if a better metric is achieved during evaluation.
1352
+ """
1353
+ metric_for_best_model = 'eval_loss' # Metric used to determine the best model (e.g., eval_loss, eval_bleu, eval_rouge)
1354
+ metric_value = metrics.get(metric_for_best_model)
1355
+
1356
+ if self.best_metric is None or metric_value < self.best_metric:
1357
+ self.best_metric = metric_value
1358
+ best_model_dir = os.path.join(args.output_dir, 'best_eval_model')
1359
+
1360
+ # Save generation configuration if present
1361
+ if hasattr(model, 'generation_config'):
1362
+ model.generation_config.save_pretrained(best_model_dir)
1363
+
1364
+ # Save model configuration
1365
+ model.config.save_pretrained(best_model_dir)
1366
+
1367
+ if tokenizer is not None:
1368
+ tokenizer.save_pretrained(best_model_dir)
1369
+
1370
+ # Save the best model
1371
+ if args.lora_enable:
1372
+ # Save LoRA-specific parameters
1373
+ state_dict = get_peft_state_maybe_zero_3(
1374
+ model.named_parameters(), args.lora_bias
1375
+ )
1376
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1377
+ model.named_parameters()
1378
+ )
1379
+ if args.local_rank in [-1, 0]:
1380
+ model.save_pretrained(best_model_dir, state_dict=state_dict)
1381
+ torch.save(non_lora_state_dict, os.path.join(best_model_dir, 'non_lora_trainables.bin'))
1382
+ else:
1383
+ # Save full model state when not using LoRA
1384
+ state_dict = get_non_vision_tower_state_maybe_zero_3(
1385
+ model.named_parameters()
1386
+ )
1387
+ if args.local_rank in [-1, 0]:
1388
+ model.save_pretrained(best_model_dir, state_dict=state_dict)
1389
+ # Save mm_projector state when tuning mm_mlp_adapter
1390
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=best_model_dir)
1391
+
1392
+ check_trainable_parameters(model)
1393
+
1394
+ compute_metrics_func = None
1395
+
1396
+ if model_args.compute_metrics:
1397
+ compute_metrics_func = create_compute_metrics(tokenizer, vision_tower.num_patches, conversation_lib.default_conversation.sep2)
1398
+
1399
+ model.to(training_args.device)
1400
+
1401
+ trainer = LibraTrainer(model=model,
1402
+ tokenizer=tokenizer,
1403
+ args=training_args,
1404
+ callbacks=[SaveCallback()],
1405
+ compute_metrics=compute_metrics_func,
1406
+ **data_module)
1407
+
1408
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1409
+ trainer.train(resume_from_checkpoint=True)
1410
+ else:
1411
+ trainer.train()
1412
+
1413
+ trainer.save_state()
1414
+
1415
+ model.config.use_cache = True
1416
+
1417
+ if training_args.lora_enable:
1418
+ state_dict = get_peft_state_maybe_zero_3(
1419
+ model.named_parameters(), training_args.lora_bias
1420
+ )
1421
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1422
+ model.named_parameters()
1423
+ )
1424
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1425
+ model.config.save_pretrained(training_args.output_dir)
1426
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
1427
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
1428
+ else:
1429
+ safe_save_model_for_hf_trainer(trainer=trainer,
1430
+ output_dir=training_args.output_dir)
1431
+
1432
+
1433
+ if __name__ == "__main__":
1434
+ train()
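Because train() builds its configuration with HfArgumentParser, every dataclass field becomes a command-line flag; a hedged sketch of the equivalent in-process parsing (all values are placeholders, only flags whose field names appear in the code above are used, and the remaining fields are assumed to have usable defaults):

import transformers

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "path/to/base-llm",     # placeholder
    "--vision_tower", "path/to/vision-encoder",     # placeholder
    "--data_path", "path/to/train.json",            # placeholder
    "--validation_data_path", "path/to/val.json",   # placeholder
    "--image_folder", "path/to/images",             # placeholder
    "--output_dir", "./checkpoints/libra",          # placeholder
    "--bf16", "True",
])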
libra/train/train_mem.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4
+
5
+ # # Need to call this before importing transformers.
6
+
7
+ # from libra.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
8
+ # replace_llama_attn_with_flash_attn()
9
+
10
+ from libra.train.llama2_flash_attn_monkey_patch import (
11
+ replace_llama_attn_with_flash_attn,
12
+ )
13
+
14
+ replace_llama_attn_with_flash_attn()
15
+
16
+ from libra.train.train import train
17
+
18
+ if __name__ == "__main__":
19
+ train()
libra/train/train_xformers.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
4
+
5
+ # Need to call this before importing transformers.
6
+ from libra.train.llama_xformers_attn_monkey_patch import (
7
+ replace_llama_attn_with_xformers_attn,
8
+ )
9
+
10
+ replace_llama_attn_with_xformers_attn()
11
+
12
+ from libra.train.train import train
13
+
14
+ if __name__ == "__main__":
15
+ train()
libra/utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from libra.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def disable_torch_init():
18
+ """
19
+ Disable the redundant torch default initialization to accelerate model creation.
20
+ """
21
+ import torch
22
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24
+
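A brief usage sketch for disable_torch_init: it is meant to be called before constructing a model whose weights come from a checkpoint, so the skipped default initialization is never actually needed:

from libra.utils import disable_torch_init

disable_torch_init()
# model = LibraLlamaForCausalLM.from_pretrained(...)  # loaded weights overwrite any initialization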