KingNish committed
Commit 5d99e45 · verified · 1 Parent(s): 6417a68

Update modeling/bagel/bagel.py

Files changed (1)
  1. modeling/bagel/bagel.py +1039 -1025
modeling/bagel/bagel.py CHANGED
@@ -1,1026 +1,1040 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import copy
5
- from typing import List, Tuple, Optional
6
- import matplotlib.pyplot as plt
7
-
8
- from PIL import Image
9
- import torch
10
- import torch.nn.functional as F
11
- from torch import nn
12
- from torch.nn.attention.flex_attention import create_block_mask
13
- from transformers.configuration_utils import PretrainedConfig
14
- from transformers.modeling_utils import PreTrainedModel
15
-
16
- from data.data_utils import (
17
- create_sparse_mask,
18
- get_flattened_position_ids_extrapolate,
19
- get_flattened_position_ids_interpolate,
20
- patchify,
21
- )
22
- from .qwen2_navit import NaiveCache
23
- from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
24
-
25
-
26
- class BagelConfig(PretrainedConfig):
27
- def __init__(
28
- self,
29
- visual_gen=True,
30
- visual_und=True,
31
- llm_config=None,
32
- vit_config=None,
33
- vae_config=None,
34
- latent_patch_size=2,
35
- max_latent_size=32,
36
- vit_max_num_patch_per_side=70,
37
- connector_act="gelu_pytorch_tanh",
38
- interpolate_pos=False,
39
- timestep_shift=1.0,
40
- **kwargs
41
- ):
42
- super().__init__(**kwargs)
43
- self.visual_gen = visual_gen
44
- self.visual_und = visual_und
45
- self.llm_config = llm_config
46
- self.vit_config = vit_config
47
- self.vae_config = vae_config
48
- self.latent_patch_size = latent_patch_size
49
- self.max_latent_size = max_latent_size
50
- self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
- self.connector_act = connector_act
52
- self.interpolate_pos = interpolate_pos
53
- self.timestep_shift = timestep_shift
54
-
55
-
56
- class Bagel(PreTrainedModel):
57
- config_class = BagelConfig
58
- base_model_prefix = 'bagel'
59
-
60
- def __init__(self, language_model, vit_model, config: BagelConfig):
61
- super().__init__(config)
62
- self.language_model = language_model
63
- self.hidden_size = config.llm_config.hidden_size
64
- self.use_moe = "Mo" in config.llm_config.layer_module
65
- self.num_heads = config.llm_config.num_attention_heads
66
-
67
- if config.visual_gen:
68
- self.latent_patch_size = config.latent_patch_size
69
- self.timestep_shift = config.timestep_shift
70
- self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
- self.max_latent_size = config.max_latent_size
72
- self.latent_channel = config.vae_config.z_channels
73
- self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
- self.time_embedder = TimestepEmbedder(self.hidden_size)
75
- self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
- self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
- self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
-
79
- if config.visual_und:
80
- self.vit_model = vit_model
81
- self.vit_patch_size = config.vit_config.patch_size
82
- self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
- self.vit_hidden_size = config.vit_config.hidden_size
84
- self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
- self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
-
87
- if config.interpolate_pos:
88
- self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
- else:
90
- self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
-
92
- self.config = config
93
- self._init_weights()
94
-
95
- def _init_weights(self):
96
- if self.config.visual_gen:
97
- nn.init.constant_(self.llm2vae.weight, 0)
98
- nn.init.constant_(self.llm2vae.bias, 0)
99
-
100
- def forward(
101
- self,
102
- sequence_length: int,
103
- packed_text_ids: torch.LongTensor,
104
- packed_text_indexes: torch.LongTensor,
105
- sample_lens: List[int],
106
- packed_position_ids: torch.LongTensor,
107
- nested_attention_masks: List[torch.Tensor] = None,
108
- split_lens: List[int] = None,
109
- attn_modes: List[str] = None,
110
- # for visual understanding
111
- ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
- packed_label_ids: Optional[torch.LongTensor] = None,
113
- packed_vit_tokens: Optional[torch.Tensor] = None,
114
- packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
- packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
- vit_token_seqlens: Optional[torch.IntTensor] = None,
117
- # for visual generation
118
- padded_latent: Optional[torch.Tensor] = None,
119
- patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
- packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
- packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
- packed_timesteps: Optional[torch.LongTensor] = None,
123
- mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
- ) -> torch.Tensor:
125
- """
126
- Args:
127
- sequence_length: length of sequence.
128
- packed_text_ids: 1-D int tensor, packed text token ids.
129
- packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
- sample_lens: A list of N ints, length of each sample in packed_sequence.
131
- nested_attention_masks: A list of N 2-D float tensors, where 0.0 means attend and
132
- -inf means ignore.
133
- packed_position_ids: packed 1-D positions, an image has only one global position shared
134
- by all latent tokens.
135
-
136
- packed_vit_tokens: packed patchified image tokens for vit model.
137
- packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
- packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
- vit_token_seqlens: 1-D int tensor, the number of vit tokens for each image.
140
- packed_label_ids: 1-D int tensor, packed label token ids.
141
- ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
-
143
- padded_latent: padded latent from VAE encoder.
144
- patchified_vae_latent_shapes: A list of (h, w) tuples, patchified latent shapes of each image.
145
- packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
- packed_vae_token_indexes: 1-D int tensor, packed image token indexes in sequence.
147
- packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
- mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
- """
150
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
- packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
- packed_sequence[packed_text_indexes] = packed_text_embedding
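- # All modalities share one flat packed sequence: the text embeddings are scattered in here, and
- # ViT / VAE tokens are written into their reserved slots below via the matching *_indexes tensors.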
153
-
154
- if nested_attention_masks is None:
155
- sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
- seqlen = sum(sample_lens)
157
- block_mask = create_block_mask(
158
- sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
- device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
- )
161
- attention_mask = block_mask
162
- else:
163
- attention_mask = nested_attention_masks
164
-
165
- if self.config.visual_und:
166
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
167
- cu_seqlens = cu_seqlens.to(torch.int32)
168
- max_seqlen = torch.max(vit_token_seqlens).item()
169
- packed_vit_token_embed = self.vit_model(
170
- packed_pixel_values=packed_vit_tokens,
171
- packed_flattened_position_ids=packed_vit_position_ids,
172
- cu_seqlens=cu_seqlens,
173
- max_seqlen=max_seqlen,
174
- )
175
- packed_vit_token_embed = self.connector(packed_vit_token_embed)
176
- vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
177
- packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
178
- packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
179
-
180
- if self.config.visual_gen:
181
- p = self.latent_patch_size
182
- packed_latent = []
183
- for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
184
- latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
185
- latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
186
- packed_latent.append(latent)
187
- packed_latent_clean = torch.cat(packed_latent, dim=0)
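- # Each padded latent is cropped to its valid h*p x w*p region and regrouped into h*w patch
- # tokens of size p*p*latent_channel, the input dimension expected by vae2llm.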
188
-
189
- noise = torch.randn_like(packed_latent_clean)
190
- packed_timesteps = torch.sigmoid(packed_timesteps)
191
- packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
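- # sigmoid maps the raw timesteps into (0, 1); the shift s then remaps t -> s*t / (1 + (s - 1)*t),
- # which is the identity for s = 1 and pushes samples toward t = 1 (pure noise) for s > 1.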
192
- packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
193
- packed_timestep_embeds = self.time_embedder(packed_timesteps)
194
- latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
195
- packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
196
- packed_sequence[packed_vae_token_indexes] = packed_latent
197
-
198
- extra_inputs = {}
199
- if self.use_moe:
200
- packed_und_token_indexes = packed_text_indexes
201
- if packed_vit_token_indexes is not None:
202
- packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
203
- extra_inputs.update(
204
- packed_und_token_indexes=packed_und_token_indexes,
205
- packed_gen_token_indexes=packed_vae_token_indexes,
206
- )
207
-
208
- last_hidden_state = self.language_model(
209
- packed_sequence=packed_sequence,
210
- sample_lens=sample_lens,
211
- attention_mask=attention_mask,
212
- packed_position_ids=packed_position_ids,
213
- **extra_inputs,
214
- )
215
-
216
- mse = None
217
- if self.config.visual_gen:
218
- packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
219
- target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
220
- has_mse = packed_timesteps > 0
221
- mse = (packed_mse_preds - target[has_mse]) ** 2
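- # Rectified-flow objective: with x_t = (1 - t) * x_0 + t * x_1 and x_1 = noise, the regression
- # target is v = x_1 - x_0; the per-element loss is kept only where the timestep is > 0.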
222
-
223
- ce = None
224
- if ce_loss_indexes is not None:
225
- packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
226
- ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
227
-
228
- return dict(mse=mse, ce=ce)
229
-
230
-
231
- def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
232
- packed_text_ids = list()
233
- packed_text_position_ids = list()
234
- text_token_lens = list()
235
- packed_text_indexes = list()
236
- packed_key_value_indexes = list()
237
-
238
- curr = 0
239
- newlens, new_rope = list(), list()
240
- for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
241
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
242
- curr += curr_kvlen
243
-
244
- text_ids = tokenizer.encode(prompt)
245
- text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
246
- text_token_lens.append(len(text_ids))
247
- packed_text_ids.extend(text_ids)
248
- packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
249
- packed_text_indexes.extend(range(curr, curr + len(text_ids)))
250
- newlens.append(curr_kvlen + len(text_ids))
251
- new_rope.append(curr_position_id + len(text_ids))
252
- curr += len(text_ids)
253
-
254
- generation_input = {
255
- "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
256
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
257
- "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
258
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
259
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
260
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
261
- }
262
-
263
- return generation_input, newlens, new_rope
264
-
265
- @torch.no_grad
266
- def forward_cache_update_text(
267
- self,
268
- past_key_values: NaiveCache,
269
- packed_text_ids: torch.IntTensor,
270
- packed_text_position_ids: torch.LongTensor,
271
- text_token_lens: torch.LongTensor,
272
- packed_text_indexes: torch.LongTensor,
273
- packed_key_value_indexes: torch.LongTensor,
274
- key_values_lens: torch.IntTensor,
275
- ):
276
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
277
-
278
- extra_inputs = {}
279
- if self.use_moe:
280
- extra_inputs = {"mode": "und"}
281
-
282
- output = self.language_model.forward_inference(
283
- packed_query_sequence=packed_text_embedding,
284
- query_lens=text_token_lens,
285
- packed_query_position_ids=packed_text_position_ids,
286
- packed_query_indexes=packed_text_indexes,
287
- past_key_values=past_key_values,
288
- packed_key_value_indexes=packed_key_value_indexes,
289
- key_values_lens=key_values_lens,
290
- update_past_key_values=True,
291
- is_causal=True,
292
- **extra_inputs,
293
- )
294
- past_key_values = output.past_key_values
295
-
296
- return past_key_values
297
-
298
- def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
299
- packed_vit_token_indexes = list()
300
- vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
301
- packed_text_ids, packed_text_indexes = list(), list()
302
- packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
303
- packed_key_value_indexes = list()
304
-
305
- _curr = curr = 0
306
- newlens, new_rope = list(), list()
307
- for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
308
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
309
- curr += curr_kvlen
310
-
311
- packed_text_ids.append(new_token_ids['start_of_image'])
312
- packed_text_indexes.append(_curr)
313
- packed_indexes.append(curr)
314
- curr += 1
315
- _curr += 1
316
-
317
- image_tensor = transforms(image)
318
- vit_position_ids = self.get_flattened_position_ids(
319
- image_tensor.size(1), image_tensor.size(2),
320
- self.vit_patch_size,
321
- max_num_patches_per_side=self.vit_max_num_patch_per_side
322
- )
323
- vit_tokens = patchify(image_tensor, self.vit_patch_size)
324
- packed_vit_tokens.append(vit_tokens)
325
- num_img_tokens = vit_tokens.shape[0]
326
- packed_vit_position_ids.append(vit_position_ids)
327
- vit_token_seqlens.append(num_img_tokens)
328
- packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
329
- packed_indexes.extend(range(curr, curr + num_img_tokens))
330
- curr += num_img_tokens
331
- _curr += num_img_tokens
332
-
333
- packed_text_ids.append(new_token_ids['end_of_image'])
334
- packed_text_indexes.append(_curr)
335
- packed_indexes.append(curr)
336
- curr += 1
337
- _curr += 1
338
-
339
- packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
340
- packed_seqlens.append(num_img_tokens + 2)
341
- newlens.append(curr_kvlen + num_img_tokens + 2)
342
- new_rope.append(curr_position_id + 1)
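- # _curr indexes the packed query sequence while curr indexes the KV-cache layout; all tokens of
- # an image share one RoPE position, so the rope counter only advances by 1 per image.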
343
-
344
- generation_input = {
345
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
346
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
347
- "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
348
- "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
349
- "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
350
- "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
351
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
352
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
353
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
354
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
355
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
356
- }
357
-
358
- return generation_input, newlens, new_rope
359
-
360
- @torch.no_grad
361
- def forward_cache_update_vit(
362
- self,
363
- past_key_values: NaiveCache,
364
- packed_text_ids: torch.LongTensor,
365
- packed_text_indexes: torch.LongTensor,
366
- packed_vit_tokens: torch.Tensor,
367
- packed_vit_token_indexes: torch.LongTensor,
368
- packed_vit_position_ids: torch.LongTensor,
369
- vit_token_seqlens: torch.IntTensor,
370
- packed_position_ids: torch.LongTensor,
371
- packed_seqlens: torch.IntTensor,
372
- packed_indexes: torch.LongTensor,
373
- packed_key_value_indexes: torch.LongTensor,
374
- key_values_lens: torch.IntTensor,
375
- ):
376
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
377
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
378
- packed_sequence[packed_text_indexes] = packed_text_embedding
379
-
380
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
381
- cu_seqlens = cu_seqlens.to(torch.int32)
382
- max_seqlen = torch.max(vit_token_seqlens).item()
383
- packed_vit_token_embed = self.vit_model(
384
- packed_pixel_values=packed_vit_tokens,
385
- packed_flattened_position_ids=packed_vit_position_ids,
386
- cu_seqlens=cu_seqlens,
387
- max_seqlen=max_seqlen,
388
- )
389
- packed_vit_token_embed = self.connector(packed_vit_token_embed)
390
- pos_emb = self.vit_pos_embed(packed_vit_position_ids)
391
- packed_vit_token_embed = packed_vit_token_embed + pos_emb
392
- packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
393
-
394
- extra_inputs = {}
395
- if self.use_moe:
396
- extra_inputs = {"mode": "und"}
397
-
398
- output = self.language_model.forward_inference(
399
- packed_query_sequence=packed_sequence,
400
- query_lens=packed_seqlens,
401
- packed_query_position_ids=packed_position_ids,
402
- packed_query_indexes=packed_indexes,
403
- past_key_values=past_key_values,
404
- packed_key_value_indexes=packed_key_value_indexes,
405
- key_values_lens=key_values_lens,
406
- update_past_key_values=True,
407
- is_causal=False,
408
- **extra_inputs,
409
- )
410
- past_key_values = output.past_key_values
411
-
412
- return past_key_values
413
-
414
- def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
415
- patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
416
- packed_vae_token_indexes = list()
417
- packed_text_ids, packed_text_indexes = list(), list()
418
- packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
419
- packed_key_value_indexes = list()
420
-
421
- _curr = curr = 0
422
- vae_image_tensors = list()
423
- newlens, new_rope = list(), list()
424
- for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
425
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
426
- curr += curr_kvlen
427
-
428
- packed_text_ids.append(new_token_ids['start_of_image'])
429
- packed_text_indexes.append(_curr)
430
- packed_indexes.append(curr)
431
- curr += 1
432
- _curr += 1
433
-
434
- image_tensor = transforms(image)
435
- vae_image_tensors.append(image_tensor)
436
- vae_position_ids = self.get_flattened_position_ids(
437
- image_tensor.size(1), image_tensor.size(2),
438
- self.latent_downsample,
439
- max_num_patches_per_side=self.max_latent_size
440
- )
441
- packed_vae_position_ids.append(vae_position_ids)
442
- H, W = image_tensor.shape[1:]
443
- h = H // self.latent_downsample
444
- w = W // self.latent_downsample
445
- patchified_vae_latent_shapes.append((h, w))
446
-
447
- num_img_tokens = w * h
448
- packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
449
- packed_indexes.extend(range(curr, curr + num_img_tokens))
450
- curr += num_img_tokens
451
- _curr += num_img_tokens
452
-
453
- packed_text_ids.append(new_token_ids['end_of_image'])
454
- packed_text_indexes.append(_curr)
455
- packed_indexes.append(curr)
456
- curr += 1
457
- _curr += 1
458
-
459
- packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
460
- packed_seqlens.append(num_img_tokens + 2)
461
- newlens.append(curr_kvlen + num_img_tokens + 2)
462
- new_rope.append(curr_position_id + 1)
463
-
464
- image_sizes = [item.shape for item in vae_image_tensors]
465
- max_image_size = [max(item) for item in list(zip(*image_sizes))]
466
- padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
467
- for i, image_tensor in enumerate(vae_image_tensors):
468
- padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
469
-
470
- generation_input = {
471
- "padded_images": padded_images,
472
- "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
473
- "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
474
- "packed_timesteps": torch.tensor([timestep]),
475
- "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
476
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
477
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
478
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
479
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
480
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
481
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
482
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
483
- }
484
-
485
- return generation_input, newlens, new_rope
486
-
487
- @torch.no_grad
488
- def forward_cache_update_vae(
489
- self,
490
- vae_model,
491
- past_key_values: NaiveCache,
492
- padded_images: torch.Tensor,
493
- patchified_vae_latent_shapes: List,
494
- packed_vae_position_ids: torch.LongTensor,
495
- packed_timesteps: torch.Tensor,
496
- packed_vae_token_indexes: torch.LongTensor,
497
- packed_text_ids: torch.LongTensor,
498
- packed_text_indexes: torch.LongTensor,
499
- packed_position_ids: torch.LongTensor,
500
- packed_seqlens: torch.IntTensor,
501
- packed_indexes: torch.LongTensor,
502
- key_values_lens: torch.IntTensor,
503
- packed_key_value_indexes: torch.Tensor,
504
- ):
505
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
506
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
507
- packed_sequence[packed_text_indexes] = packed_text_embedding
508
-
509
- padded_latent = vae_model.encode(padded_images)
510
-
511
- p = self.latent_patch_size
512
- packed_latent = list()
513
- for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
514
- latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
515
- latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
516
- packed_latent.append(latent)
517
- packed_latent = torch.cat(packed_latent, dim=0)
518
- packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
519
- packed_timestep_embeds = self.time_embedder(packed_timesteps)
520
- packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
521
- packed_sequence[packed_vae_token_indexes] = packed_latent
522
-
523
- extra_inputs = {}
524
- if self.use_moe:
525
- extra_inputs = {
526
- "mode": "gen",
527
- "packed_vae_token_indexes": packed_vae_token_indexes,
528
- "packed_text_indexes": packed_text_indexes
529
- }
530
-
531
- output = self.language_model.forward_inference(
532
- packed_query_sequence=packed_sequence,
533
- query_lens=packed_seqlens,
534
- packed_query_position_ids=packed_position_ids,
535
- packed_query_indexes=packed_indexes,
536
- past_key_values=past_key_values,
537
- key_values_lens=key_values_lens,
538
- packed_key_value_indexes=packed_key_value_indexes,
539
- update_past_key_values=True,
540
- is_causal=False,
541
- **extra_inputs,
542
- )
543
- past_key_values = output.past_key_values
544
-
545
- return past_key_values
546
-
547
- def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
548
- packed_text_ids, packed_text_indexes = list(), list()
549
- packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
550
- packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
551
- packed_key_value_indexes = list()
552
-
553
- query_curr = curr = 0
554
- for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
555
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
556
- curr += curr_kvlen
557
-
558
- packed_text_ids.append(new_token_ids['start_of_image'])
559
- packed_text_indexes.append(query_curr)
560
- packed_indexes.append(curr)
561
- curr += 1
562
- query_curr += 1
563
-
564
- vae_position_ids = self.get_flattened_position_ids(
565
- H, W,
566
- self.latent_downsample,
567
- max_num_patches_per_side=self.max_latent_size
568
- )
569
- packed_vae_position_ids.append(vae_position_ids)
570
-
571
- h, w = H // self.latent_downsample, W // self.latent_downsample
572
- num_image_tokens = h * w
573
- packed_init_noises.append(
574
- torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
575
- )
576
- packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
577
- packed_indexes.extend(range(curr, curr + num_image_tokens))
578
- curr += num_image_tokens
579
- query_curr += num_image_tokens
580
-
581
- packed_text_ids.append(new_token_ids['end_of_image'])
582
- packed_text_indexes.append(query_curr)
583
- packed_indexes.append(curr)
584
- curr += 1
585
- query_curr += 1
586
-
587
- packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
588
- packed_seqlens.append(num_image_tokens + 2)
589
-
590
- generation_input = {
591
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
592
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
593
- "packed_init_noises": torch.cat(packed_init_noises, dim=0),
594
- "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
595
- "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
596
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
597
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
598
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
599
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
600
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
601
- }
602
-
603
- return generation_input
604
-
605
- def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
606
- packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
607
-
608
- query_curr = curr = 0
609
- for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
610
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
611
- curr += curr_kvlen
612
-
613
- packed_indexes.append(curr)
614
- curr += 1
615
- query_curr += 1
616
-
617
- h, w = H // self.latent_downsample, W // self.latent_downsample
618
- num_image_tokens = h * w
619
- packed_indexes.extend(range(curr, curr + num_image_tokens))
620
- curr += num_image_tokens
621
- query_curr += num_image_tokens
622
-
623
- packed_indexes.append(curr)
624
- curr += 1
625
- query_curr += 1
626
-
627
- packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
628
-
629
- generation_input = {
630
- "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
631
- "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
632
- "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
633
- "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
634
- }
635
-
636
- return generation_input
637
-
638
- @torch.no_grad
639
- def generate_image(
640
- self,
641
- packed_text_ids: torch.LongTensor,
642
- packed_text_indexes: torch.LongTensor,
643
- packed_init_noises: torch.Tensor,
644
- packed_vae_position_ids: torch.LongTensor,
645
- packed_vae_token_indexes: torch.LongTensor,
646
- packed_seqlens: torch.IntTensor,
647
- packed_position_ids: torch.LongTensor,
648
- packed_indexes: torch.LongTensor,
649
- past_key_values: NaiveCache,
650
- key_values_lens: torch.IntTensor,
651
- packed_key_value_indexes: torch.LongTensor,
652
- num_timesteps: int = 24,
653
- timestep_shift: float = 1.0,
654
- cfg_renorm_min: float = 0.0,
655
- cfg_renorm_type: str = "global",
656
- cfg_interval: Optional[Tuple[float, float]] = [0, 1],
657
- # cfg_text
658
- cfg_text_scale: float = 1.0,
659
- cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
660
- cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
661
- cfg_text_past_key_values: Optional[NaiveCache] = None,
662
- cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
663
- cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
664
- # cfg_img
665
- cfg_img_scale: float = 1.0,
666
- cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
667
- cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
668
- cfg_img_past_key_values: Optional[NaiveCache] = None,
669
- cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
670
- cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
671
- cfg_type: str = "parallel",
672
- ):
673
- x_t = packed_init_noises
674
-
675
- timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
676
- timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
677
- dts = timesteps[:-1] - timesteps[1:]
678
- timesteps = timesteps[:-1]
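- # Euler sampler: integrate from t=1 (noise) to t=0 (data) with step sizes dts, updating
- # x_t <- x_t - v_t * dt because the predicted velocity points from data toward noise.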
679
-
680
- for i, t in enumerate(timesteps):
681
-
682
- timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
683
- if t > cfg_interval[0] and t <= cfg_interval[1]:
684
- cfg_text_scale_ = cfg_text_scale
685
- cfg_img_scale_ = cfg_img_scale
686
- else:
687
- cfg_text_scale_ = 1.0
688
- cfg_img_scale_ = 1.0
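- # Guidance is only active while t lies inside cfg_interval; outside it both scales fall back
- # to 1.0 (no CFG).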
689
- v_t = self._forward_flow(
690
- x_t=x_t,
691
- timestep=timestep,
692
- packed_vae_token_indexes=packed_vae_token_indexes,
693
- packed_vae_position_ids=packed_vae_position_ids,
694
- packed_text_ids=packed_text_ids,
695
- packed_text_indexes=packed_text_indexes,
696
- packed_position_ids=packed_position_ids,
697
- packed_indexes=packed_indexes,
698
- packed_seqlens=packed_seqlens,
699
- key_values_lens=key_values_lens,
700
- past_key_values=past_key_values,
701
- packed_key_value_indexes=packed_key_value_indexes,
702
- cfg_renorm_min=cfg_renorm_min,
703
- cfg_renorm_type=cfg_renorm_type,
704
- # cfg_text
705
- cfg_text_scale=cfg_text_scale_,
706
- cfg_text_packed_position_ids=cfg_text_packed_position_ids,
707
- cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
708
- cfg_text_key_values_lens=cfg_text_key_values_lens,
709
- cfg_text_past_key_values=cfg_text_past_key_values,
710
- cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
711
- # cfg_img
712
- cfg_img_scale=cfg_img_scale_,
713
- cfg_img_packed_position_ids=cfg_img_packed_position_ids,
714
- cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
715
- cfg_img_key_values_lens=cfg_img_key_values_lens,
716
- cfg_img_past_key_values=cfg_img_past_key_values,
717
- cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
718
- cfg_type=cfg_type,
719
- )
720
-
721
- x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
722
-
723
- unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
724
- return unpacked_latent
725
-
726
- @torch.no_grad
727
- def _forward_flow(
728
- self,
729
- x_t: torch.Tensor,
730
- timestep: torch.LongTensor,
731
- packed_vae_token_indexes: torch.LongTensor,
732
- packed_vae_position_ids: torch.LongTensor,
733
- packed_text_ids: torch.LongTensor,
734
- packed_text_indexes: torch.LongTensor,
735
- packed_indexes: torch.LongTensor,
736
- packed_position_ids: torch.LongTensor,
737
- packed_seqlens: torch.IntTensor,
738
- key_values_lens: torch.IntTensor,
739
- past_key_values: NaiveCache,
740
- packed_key_value_indexes: torch.LongTensor,
741
- cfg_renorm_min: float = 0.0,
742
- cfg_renorm_type: str = "global",
743
- # cfg_text
744
- cfg_text_scale: float = 1.0,
745
- cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
746
- cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
747
- cfg_text_key_values_lens: Optional[torch.Tensor] = None,
748
- cfg_text_past_key_values: Optional[NaiveCache] = None,
749
- cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
750
- # cfg_img
751
- cfg_img_scale: float = 1.0,
752
- cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
753
- cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
754
- cfg_img_key_values_lens: Optional[torch.Tensor] = None,
755
- cfg_img_past_key_values: Optional[NaiveCache] = None,
756
- cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
757
- cfg_type: str = "parallel",
758
- ):
759
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
760
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
761
- packed_sequence[packed_text_indexes] = packed_text_embedding
762
-
763
- assert timestep.unique().shape[0] == 1
764
- packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
765
- packed_timestep_embeds = self.time_embedder(timestep)
766
- x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
767
- packed_sequence[packed_vae_token_indexes] = x_t
768
-
769
- extra_inputs = {}
770
- if self.use_moe:
771
- extra_inputs = {
772
- "mode": "gen",
773
- "packed_vae_token_indexes": packed_vae_token_indexes,
774
- "packed_text_indexes": packed_text_indexes
775
- }
776
-
777
- output = self.language_model.forward_inference(
778
- packed_query_sequence=packed_sequence,
779
- query_lens=packed_seqlens,
780
- packed_query_position_ids=packed_position_ids,
781
- packed_query_indexes=packed_indexes,
782
- past_key_values=past_key_values,
783
- key_values_lens=key_values_lens,
784
- packed_key_value_indexes=packed_key_value_indexes,
785
- update_past_key_values=False,
786
- is_causal=False,
787
- **extra_inputs,
788
- )
789
- v_t = self.llm2vae(output.packed_query_sequence)
790
- v_t = v_t[packed_vae_token_indexes]
791
-
792
- if cfg_text_scale > 1.0:
793
- cfg_text_output = self.language_model.forward_inference(
794
- packed_query_sequence=packed_sequence,
795
- query_lens=packed_seqlens,
796
- packed_query_position_ids=cfg_text_packed_position_ids,
797
- packed_query_indexes=cfg_text_packed_query_indexes,
798
- past_key_values=cfg_text_past_key_values,
799
- key_values_lens=cfg_text_key_values_lens,
800
- packed_key_value_indexes=cfg_text_packed_key_value_indexes,
801
- update_past_key_values=False,
802
- is_causal=False,
803
- **extra_inputs,
804
- )
805
- cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
806
- cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
807
-
808
- if cfg_img_scale > 1.0:
809
- cfg_img_output = self.language_model.forward_inference(
810
- packed_query_sequence=packed_sequence,
811
- query_lens=packed_seqlens,
812
- packed_query_position_ids=cfg_img_packed_position_ids,
813
- packed_query_indexes=cfg_img_packed_query_indexes,
814
- past_key_values=cfg_img_past_key_values,
815
- key_values_lens=cfg_img_key_values_lens,
816
- packed_key_value_indexes=cfg_img_packed_key_value_indexes,
817
- update_past_key_values=False,
818
- is_causal=False,
819
- **extra_inputs,
820
- )
821
- cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
822
- cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
823
-
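- # Classifier-free guidance: the text-dropped / image-dropped passes give weaker-conditioned
- # velocities, combined as v <- v_weak + scale * (v_full - v_weak); the renorm options below then
- # rescale the guided velocity so its norm does not exceed that of the fully conditioned v_t.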
824
- if cfg_text_scale > 1.0:
825
- if cfg_renorm_type == "text_channel":
826
- v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
827
- norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
828
- norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
829
- scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
830
- v_t_text = v_t_text_ * scale
831
- if cfg_img_scale > 1.0:
832
- v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
833
- else:
834
- v_t = v_t_text
835
- else:
836
- v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
837
-
838
- if cfg_img_scale > 1.0:
839
- v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
840
- else:
841
- v_t_ = v_t_text_
842
-
843
- # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
844
- if cfg_renorm_type == "global":
845
- norm_v_t = torch.norm(v_t)
846
- norm_v_t_ = torch.norm(v_t_)
847
- elif cfg_renorm_type == "channel":
848
- norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
849
- norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
850
- else:
851
- raise NotImplementedError(f"{cfg_renorm_type} is not supported")
852
- scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
853
- v_t = v_t_ * scale
854
- else:
855
- # No CFG
856
- pass
857
-
858
- return v_t
859
-
860
- def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
861
- packed_start_tokens, packed_key_value_indexes = list(), list()
862
- packed_query_position_ids = list()
863
-
864
- curr = 0
865
- for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
866
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
867
- packed_start_tokens.append(new_token_ids['bos_token_id'])
868
- packed_query_position_ids.append(curr_position_id)
869
- curr += curr_kvlen
870
-
871
- generation_input = {
872
- "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
873
- "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
874
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
875
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
876
- }
877
-
878
- return generation_input
879
-
880
- @torch.no_grad
881
- def generate_text(
882
- self,
883
- past_key_values: NaiveCache,
884
- packed_key_value_indexes: torch.LongTensor,
885
- key_values_lens: torch.IntTensor,
886
- packed_start_tokens: torch.LongTensor,
887
- packed_query_position_ids: torch.LongTensor,
888
- max_length: int,
889
- do_sample: bool = False,
890
- temperature: float = 1.0,
891
- end_token_id: int = None,
892
- ):
893
- step = 0
894
- # generated_sequence = [] # Removed for streaming
895
- curr_tokens = packed_start_tokens
896
- while step < max_length:
897
- # generated_sequence.append(curr_tokens) # Removed for streaming
898
- yield curr_tokens # Yield current tokens
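- # generate_text is now a generator: each step's tokens are yielded immediately instead of being
- # accumulated in generated_sequence and returned at the end.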
899
-
900
- packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
901
- query_lens = torch.ones_like(curr_tokens)
902
- packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
903
- 0, len(key_values_lens),
904
- device=key_values_lens.device,
905
- dtype=key_values_lens.dtype
906
- )
907
-
908
- unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
910
- for i in range(len(unpacked)):
911
- unpacked[i] += i
912
- packed_key_value_indexes = torch.cat(unpacked, dim=0)
912
-
913
- extra_inputs = {}
914
- if self.use_moe:
915
- extra_inputs = {"mode": "und"}
916
-
917
- output = self.language_model.forward_inference(
918
- packed_query_sequence=packed_text_embedding,
919
- query_lens=query_lens,
920
- packed_query_position_ids=packed_query_position_ids,
921
- packed_query_indexes=packed_query_indexes,
922
- past_key_values=past_key_values,
923
- key_values_lens=key_values_lens,
924
- packed_key_value_indexes=packed_key_value_indexes,
925
- update_past_key_values=True,
926
- is_causal=True,
927
- **extra_inputs,
928
- )
929
- past_key_values = output.past_key_values
930
- packed_query_sequence = output.packed_query_sequence
931
- pred_logits = self.language_model.lm_head(packed_query_sequence)
932
-
933
- if do_sample:
934
- probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
935
- curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
936
- else:
937
- curr_tokens = torch.argmax(pred_logits, dim=-1)
938
-
939
- unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
940
- for i in range(len(unpacked)):
941
- unpacked[i] = torch.cat(
942
- [unpacked[i], torch.tensor([unpacked[i][-1] + 1], device=unpacked[i].device)], dim=0
943
- )
944
- packed_key_value_indexes = torch.cat(unpacked, dim=0)
945
- key_values_lens = key_values_lens + 1
946
- packed_query_position_ids = packed_query_position_ids + 1
947
- step += 1
948
-
949
- if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
950
- break
951
-
952
- # output_device = generated_sequence[0].device # Removed for streaming
953
- # return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) # Removed for streaming
954
-
955
- # for evaluation
956
- @torch.no_grad()
957
- def chat(
958
- self,
959
- tokenizer,
960
- new_token_ids,
961
- image_transform,
962
- images,
963
- prompt,
964
- max_length: int,
965
- do_sample: bool = False,
966
- temperature: float = 1.0,
967
- ):
968
- device = next(self.parameters()).device
969
-
970
- if isinstance(new_token_ids, dict):
971
- for k, v in new_token_ids.items():
972
- if torch.is_tensor(v):
973
- new_token_ids[k] = v.to(device)
974
- elif torch.is_tensor(new_token_ids):
975
- new_token_ids = new_token_ids.to(device)
976
-
977
- # prefill
978
- past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
979
- newlens = [0]
980
- new_rope = [0]
981
-
982
- # add images
983
- for image in images:
984
- generation_input, newlens, new_rope = self.prepare_vit_images(
985
- curr_kvlens=newlens,
986
- curr_rope=new_rope,
987
- images=[image],
988
- transforms=image_transform,
989
- new_token_ids=new_token_ids,
990
- )
991
- for k, v in generation_input.items():
992
- if torch.is_tensor(v):
993
- generation_input[k] = v.to(device)
994
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
995
- past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
996
-
997
- # add text
998
- generation_input, newlens, new_rope = self.prepare_prompts(
999
- curr_kvlens=newlens,
1000
- curr_rope=new_rope,
1001
- prompts=[prompt],
1002
- tokenizer=tokenizer,
1003
- new_token_ids=new_token_ids,
1004
- )
1005
- for k, v in generation_input.items():
1006
- if torch.is_tensor(v):
1007
- generation_input[k] = v.to(device)
1008
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1009
- past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1010
-
1011
- # decode
1012
- generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1013
- for k, v in generation_input.items():
1014
- if torch.is_tensor(v):
1015
- generation_input[k] = v.to(device)
1016
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1017
- for curr_tokens in self.generate_text(
1018
- past_key_values=past_key_values,
1019
- max_length=max_length,
1020
- do_sample=do_sample,
1021
- temperature=temperature,
1022
- end_token_id=new_token_ids['eos_token_id'],
1023
- **generation_input,
1024
- ):
1025
- output = tokenizer.decode(curr_tokens)
1026
  yield output
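Since generate_text now yields tokens as they are produced, chat streams its decoded output piece by piece. A minimal consumption sketch, assuming an already constructed Bagel instance model plus a matching tokenizer, image_transform, new_token_ids mapping, and a PIL image img (none of which are defined in this file):

for piece in model.chat(
    tokenizer=tokenizer,
    new_token_ids=new_token_ids,
    image_transform=image_transform,
    images=[img],
    prompt="Describe this image.",
    max_length=128,
):
    # each yielded piece is the text decoded from the latest generated token
    print(piece, end="", flush=True)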
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import copy
5
+ from typing import List, Tuple, Optional
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.nn.attention.flex_attention import create_block_mask
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from data.data_utils import (
17
+ create_sparse_mask,
18
+ get_flattened_position_ids_extrapolate,
19
+ get_flattened_position_ids_interpolate,
20
+ patchify,
21
+ )
22
+ from .qwen2_navit import NaiveCache
23
+ from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
24
+
25
+
26
+ class BagelConfig(PretrainedConfig):
27
+ def __init__(
28
+ self,
29
+ visual_gen=True,
30
+ visual_und=True,
31
+ llm_config=None,
32
+ vit_config=None,
33
+ vae_config=None,
34
+ latent_patch_size=2,
35
+ max_latent_size=32,
36
+ vit_max_num_patch_per_side=70,
37
+ connector_act="gelu_pytorch_tanh",
38
+ interpolate_pos=False,
39
+ timestep_shift=1.0,
40
+ **kwargs
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.visual_gen = visual_gen
44
+ self.visual_und = visual_und
45
+ self.llm_config = llm_config
46
+ self.vit_config = vit_config
47
+ self.vae_config = vae_config
48
+ self.latent_patch_size = latent_patch_size
49
+ self.max_latent_size = max_latent_size
50
+ self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
+ self.connector_act = connector_act
52
+ self.interpolate_pos = interpolate_pos
53
+ self.timestep_shift = timestep_shift
54
+
55
+
56
+ class Bagel(PreTrainedModel):
57
+ config_class = BagelConfig
58
+ base_model_prefix = 'bagel'
59
+
60
+ def __init__(self, language_model, vit_model, config: BagelConfig):
61
+ super().__init__(config)
62
+ self.language_model = language_model
63
+ self.hidden_size = config.llm_config.hidden_size
64
+ self.use_moe = "Mo" in config.llm_config.layer_module
65
+ self.num_heads = config.llm_config.num_attention_heads
66
+
67
+ if config.visual_gen:
68
+ self.latent_patch_size = config.latent_patch_size
69
+ self.timestep_shift = config.timestep_shift
70
+ self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
+ self.max_latent_size = config.max_latent_size
72
+ self.latent_channel = config.vae_config.z_channels
73
+ self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
+ self.time_embedder = TimestepEmbedder(self.hidden_size)
75
+ self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
+ self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
+ self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
+
79
+ if config.visual_und:
80
+ self.vit_model = vit_model
81
+ self.vit_patch_size = config.vit_config.patch_size
82
+ self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
+ self.vit_hidden_size = config.vit_config.hidden_size
84
+ self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
+ self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
+
87
+ if config.interpolate_pos:
88
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
+ else:
90
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
+
92
+ self.config = config
93
+ self._init_weights()
94
+
95
+ def _init_weights(self):
96
+ if self.config.visual_gen:
97
+ nn.init.constant_(self.llm2vae.weight, 0)
98
+ nn.init.constant_(self.llm2vae.bias, 0)
99
+
100
+ def forward(
101
+ self,
102
+ sequence_length: int,
103
+ packed_text_ids: torch.LongTensor,
104
+ packed_text_indexes: torch.LongTensor,
105
+ sample_lens: List[int],
106
+ packed_position_ids: torch.LongTensor,
107
+ nested_attention_masks: List[torch.Tensor] = None,
108
+ split_lens: List[int] = None,
109
+ attn_modes: List[str] = None,
110
+ # for visual understanding
111
+ ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
+ packed_label_ids: Optional[torch.LongTensor] = None,
113
+ packed_vit_tokens: Optional[torch.Tensor] = None,
114
+ packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
+ packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
+ vit_token_seqlens: Optional[torch.IntTensor] = None,
117
+ # for visual generation
118
+ padded_latent: Optional[torch.Tensor] = None,
119
+ patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
+ packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
+ packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
+ packed_timesteps: Optional[torch.LongTensor] = None,
123
+ mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Args:
127
+ sequence_length: length of sequence.
128
+ packed_text_ids: 1-D int tensor, packed text token ids.
129
+ packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
+ sample_lens: A list of N ints, length of each sample in packed_sequence.
131
+ nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
132
+ -inf means ignore.
133
+ packed_position_ids: packed 1-D positions, an image has only one global position shared
134
+ by all latent tokens.
135
+
136
+ packed_vit_tokens: packed patchified image tokens for vit model.
137
+ packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
+ packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
+ vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
140
+ packed_label_ids: 1-D int tensor, packed label token ids.
141
+ ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
+
143
+ padded_latent: padded latent from VAE encoder.
144
+ patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
145
+ packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
+ packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
147
+ packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
+ mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
+ """
150
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
+ packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
+ packed_sequence[packed_text_indexes] = packed_text_embedding
153
+
154
+ if nested_attention_masks is None:
155
+ sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
+ seqlen = sum(sample_lens)
157
+ block_mask = create_block_mask(
158
+ sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
+ device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
+ )
161
+ attention_mask = block_mask
162
+ else:
163
+ attention_mask = nested_attention_masks
164
+
165
+ if self.config.visual_und:
166
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
167
+ cu_seqlens = cu_seqlens.to(torch.int32)
168
+ max_seqlen = torch.max(vit_token_seqlens).item()
169
+ packed_vit_token_embed = self.vit_model(
170
+ packed_pixel_values=packed_vit_tokens,
171
+ packed_flattened_position_ids=packed_vit_position_ids,
172
+ cu_seqlens=cu_seqlens,
173
+ max_seqlen=max_seqlen,
174
+ )
175
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
176
+ vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
177
+ packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
178
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
179
+
180
+ if self.config.visual_gen:
181
+ p = self.latent_patch_size
182
+ packed_latent = []
183
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
184
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
185
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
186
+ packed_latent.append(latent)
187
+ packed_latent_clean = torch.cat(packed_latent, dim=0)
188
+
189
+ noise = torch.randn_like(packed_latent_clean)
190
+ packed_timesteps = torch.sigmoid(packed_timesteps)
191
+ packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
192
+ packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
193
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
194
+ latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
195
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
196
+ packed_sequence[packed_vae_token_indexes] = packed_latent
197
+
198
+ extra_inputs = {}
199
+ if self.use_moe:
200
+ packed_und_token_indexes = packed_text_indexes
201
+ if packed_vit_token_indexes is not None:
202
+ packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
203
+ extra_inputs.update(
204
+ packed_und_token_indexes=packed_und_token_indexes,
205
+ packed_gen_token_indexes=packed_vae_token_indexes,
206
+ )
207
+
208
+ last_hidden_state = self.language_model(
209
+ packed_sequence=packed_sequence,
210
+ sample_lens=sample_lens,
211
+ attention_mask=attention_mask,
212
+ packed_position_ids=packed_position_ids,
213
+ **extra_inputs,
214
+ )
215
+
216
+ mse = None
217
+ if self.config.visual_gen:
218
+ packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
219
+ target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
220
+ has_mse = packed_timesteps > 0
221
+ mse = (packed_mse_preds - target[has_mse]) ** 2
222
+
223
+ ce = None
224
+ if ce_loss_indexes is not None:
225
+ packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
226
+ ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
227
+
228
+ return dict(mse=mse, ce=ce)
229
+
230
+
231
+ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
232
+ packed_text_ids = list()
233
+ packed_text_position_ids = list()
234
+ text_token_lens = list()
235
+ packed_text_indexes = list()
236
+ packed_key_value_indexes = list()
237
+
238
+ curr = 0
239
+ newlens, new_rope = list(), list()
240
+ for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
241
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
242
+ curr += curr_kvlen
243
+
244
+ text_ids = tokenizer.encode(prompt)
245
+ text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
246
+ text_token_lens.append(len(text_ids))
247
+ packed_text_ids.extend(text_ids)
248
+ packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
249
+ packed_text_indexes.extend(range(curr, curr + len(text_ids)))
250
+ newlens.append(curr_kvlen + len(text_ids))
251
+ new_rope.append(curr_position_id + len(text_ids))
252
+ curr += len(text_ids)
253
+
254
+ generation_input = {
255
+ "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
256
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
257
+ "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
258
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
259
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
260
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
261
+ }
262
+
263
+ return generation_input, newlens, new_rope
264
+
265
+ @torch.no_grad
266
+ def forward_cache_update_text(
267
+ self,
268
+ past_key_values: NaiveCache,
269
+ packed_text_ids: torch.IntTensor,
270
+ packed_text_position_ids: torch.LongTensor,
271
+ text_token_lens: torch.LongTensor,
272
+ packed_text_indexes: torch.LongTensor,
273
+ packed_key_value_indexes: torch.LongTensor,
274
+ key_values_lens: torch.IntTensor,
275
+ ):
276
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
277
+
278
+ extra_inputs = {}
279
+ if self.use_moe:
280
+ extra_inputs = {"mode": "und"}
281
+
282
+ output = self.language_model.forward_inference(
283
+ packed_query_sequence=packed_text_embedding,
284
+ query_lens=text_token_lens,
285
+ packed_query_position_ids=packed_text_position_ids,
286
+ packed_query_indexes=packed_text_indexes,
287
+ past_key_values=past_key_values,
288
+ packed_key_value_indexes=packed_key_value_indexes,
289
+ key_values_lens=key_values_lens,
290
+ update_past_key_values=True,
291
+ is_causal=True,
292
+ **extra_inputs,
293
+ )
294
+ past_key_values = output.past_key_values
295
+
296
+ return past_key_values
297
+
298
+ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
299
+ packed_vit_token_indexes = list()
300
+ vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
301
+ packed_text_ids, packed_text_indexes = list(), list()
302
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
303
+ packed_key_value_indexes = list()
304
+
305
+ _curr = curr = 0
306
+ newlens, new_rope = list(), list()
307
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
308
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
309
+ curr += curr_kvlen
310
+
311
+ packed_text_ids.append(new_token_ids['start_of_image'])
312
+ packed_text_indexes.append(_curr)
313
+ packed_indexes.append(curr)
314
+ curr += 1
315
+ _curr += 1
316
+
317
+ image_tensor = transforms(image)
318
+ vit_position_ids = self.get_flattened_position_ids(
319
+ image_tensor.size(1), image_tensor.size(2),
320
+ self.vit_patch_size,
321
+ max_num_patches_per_side=self.vit_max_num_patch_per_side
322
+ )
323
+ vit_tokens = patchify(image_tensor, self.vit_patch_size)
324
+ packed_vit_tokens.append(vit_tokens)
325
+ num_img_tokens = vit_tokens.shape[0]
326
+ packed_vit_position_ids.append(vit_position_ids)
327
+ vit_token_seqlens.append(num_img_tokens)
328
+ packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
329
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
330
+ curr += num_img_tokens
331
+ _curr += num_img_tokens
332
+
333
+ packed_text_ids.append(new_token_ids['end_of_image'])
334
+ packed_text_indexes.append(_curr)
335
+ packed_indexes.append(curr)
336
+ curr += 1
337
+ _curr += 1
338
+
339
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
340
+ packed_seqlens.append(num_img_tokens + 2)
341
+ newlens.append(curr_kvlen + num_img_tokens + 2)
342
+ new_rope.append(curr_position_id + 1)
343
+
344
+ generation_input = {
345
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
346
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
347
+ "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
348
+ "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
349
+ "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
350
+ "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
351
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
352
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
353
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
354
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
355
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
356
+ }
357
+
358
+ return generation_input, newlens, new_rope
359
+
360
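# Editor's sketch (illustrative, not part of this commit): packing a PIL image into ViT
# patches for understanding. `image_transform` is the ViT-side transform, `new_token_ids`
# must contain 'start_of_image' / 'end_of_image', and "example.jpg" is a hypothetical path.
from PIL import Image

image = Image.open("example.jpg").convert("RGB")
generation_input, newlens, new_rope = model.prepare_vit_images(
    curr_kvlens=newlens,
    curr_rope=new_rope,
    images=[image],
    transforms=image_transform,
    new_token_ids=new_token_ids,
)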
+ @torch.no_grad
361
+ def forward_cache_update_vit(
362
+ self,
363
+ past_key_values: NaiveCache,
364
+ packed_text_ids: torch.LongTensor,
365
+ packed_text_indexes: torch.LongTensor,
366
+ packed_vit_tokens: torch.Tensor,
367
+ packed_vit_token_indexes: torch.LongTensor,
368
+ packed_vit_position_ids: torch.LongTensor,
369
+ vit_token_seqlens: torch.IntTensor,
370
+ packed_position_ids: torch.LongTensor,
371
+ packed_seqlens: torch.IntTensor,
372
+ packed_indexes: torch.LongTensor,
373
+ packed_key_value_indexes: torch.LongTensor,
374
+ key_values_lens: torch.IntTensor,
375
+ ):
376
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
377
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
378
+ packed_sequence[packed_text_indexes] = packed_text_embedding
379
+
380
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
381
+ cu_seqlens = cu_seqlens.to(torch.int32)
382
+ max_seqlen = torch.max(vit_token_seqlens).item()
383
+ packed_vit_token_embed = self.vit_model(
384
+ packed_pixel_values=packed_vit_tokens,
385
+ packed_flattened_position_ids=packed_vit_position_ids,
386
+ cu_seqlens=cu_seqlens,
387
+ max_seqlen=max_seqlen,
388
+ )
389
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
390
+ pos_emb = self.vit_pos_embed(packed_vit_position_ids)
391
+ packed_vit_token_embed = packed_vit_token_embed + pos_emb
392
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
393
+
394
+ extra_inputs = {}
395
+ if self.use_moe:
396
+ extra_inputs = {"mode": "und"}
397
+
398
+ output = self.language_model.forward_inference(
399
+ packed_query_sequence=packed_sequence,
400
+ query_lens=packed_seqlens,
401
+ packed_query_position_ids=packed_position_ids,
402
+ packed_query_indexes=packed_indexes,
403
+ past_key_values=past_key_values,
404
+ packed_key_value_indexes=packed_key_value_indexes,
405
+ key_values_lens=key_values_lens,
406
+ update_past_key_values=True,
407
+ is_causal=False,
408
+ **extra_inputs,
409
+ )
410
+ past_key_values = output.past_key_values
411
+
412
+ return past_key_values
413
+
414
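# Editor's sketch (illustrative, not part of this commit): writing the ViT tokens into the
# KV cache with the same device / bfloat16 handling that chat() uses below.
import torch

device = next(model.parameters()).device
generation_input = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in generation_input.items()}
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
    past_key_values = model.forward_cache_update_vit(past_key_values, **generation_input)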
+ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
415
+ patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
416
+ packed_vae_token_indexes = list()
417
+ packed_text_ids, packed_text_indexes = list(), list()
418
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
419
+ packed_key_value_indexes = list()
420
+
421
+ _curr = curr = 0
422
+ vae_image_tensors = list()
423
+ newlens, new_rope = list(), list()
424
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
425
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
426
+ curr += curr_kvlen
427
+
428
+ packed_text_ids.append(new_token_ids['start_of_image'])
429
+ packed_text_indexes.append(_curr)
430
+ packed_indexes.append(curr)
431
+ curr += 1
432
+ _curr += 1
433
+
434
+ image_tensor = transforms(image)
435
+ vae_image_tensors.append(image_tensor)
436
+ vae_position_ids = self.get_flattened_position_ids(
437
+ image_tensor.size(1), image_tensor.size(2),
438
+ self.latent_downsample,
439
+ max_num_patches_per_side=self.max_latent_size
440
+ )
441
+ packed_vae_position_ids.append(vae_position_ids)
442
+ H, W = image_tensor.shape[1:]
443
+ h = H // self.latent_downsample
444
+ w = W // self.latent_downsample
445
+ patchified_vae_latent_shapes.append((h, w))
446
+
447
+ num_img_tokens = w * h
448
+ packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
449
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
450
+ curr += num_img_tokens
451
+ _curr += num_img_tokens
452
+
453
+ packed_text_ids.append(new_token_ids['end_of_image'])
454
+ packed_text_indexes.append(_curr)
455
+ packed_indexes.append(curr)
456
+ curr += 1
457
+ _curr += 1
458
+
459
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
460
+ packed_seqlens.append(num_img_tokens + 2)
461
+ newlens.append(curr_kvlen + num_img_tokens + 2)
462
+ new_rope.append(curr_position_id + 1)
463
+
464
+ image_sizes = [item.shape for item in vae_image_tensors]
465
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
466
+ padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
467
+ for i, image_tensor in enumerate(vae_image_tensors):
468
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
469
+
470
+ generation_input = {
471
+ "padded_images": padded_images,
472
+ "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
473
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
474
+ "packed_timesteps": torch.tensor([timestep]),
475
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
476
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
477
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
478
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
479
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
480
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
481
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
482
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
483
+ }
484
+
485
+ return generation_input, newlens, new_rope
486
+
487
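# Editor's sketch (illustrative, not part of this commit): packing a conditioning image as
# clean VAE latents (timestep=0) so an editing-style generation can attend to it.
# `vae_transform` stands in for whatever VAE-resolution transform the caller uses.
generation_input, newlens, new_rope = model.prepare_vae_images(
    curr_kvlens=newlens,
    curr_rope=new_rope,
    images=[image],
    transforms=vae_transform,
    new_token_ids=new_token_ids,
    timestep=0,
)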
+ @torch.no_grad
488
+ def forward_cache_update_vae(
489
+ self,
490
+ vae_model,
491
+ past_key_values: NaiveCache,
492
+ padded_images: torch.Tensor,
493
+ patchified_vae_latent_shapes: List,
494
+ packed_vae_position_ids: torch.LongTensor,
495
+ packed_timesteps: torch.Tensor,
496
+ packed_vae_token_indexes: torch.LongTensor,
497
+ packed_text_ids: torch.LongTensor,
498
+ packed_text_indexes: torch.LongTensor,
499
+ packed_position_ids: torch.LongTensor,
500
+ packed_seqlens: torch.IntTensor,
501
+ packed_indexes: torch.LongTensor,
502
+ key_values_lens: torch.IntTensor,
503
+ packed_key_value_indexes: torch.Tensor,
504
+ ):
505
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
506
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
507
+ packed_sequence[packed_text_indexes] = packed_text_embedding
508
+
509
+ padded_latent = vae_model.encode(padded_images)
510
+
511
+ p = self.latent_patch_size
512
+ packed_latent = list()
513
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
514
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
515
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
516
+ packed_latent.append(latent)
517
+ packed_latent = torch.cat(packed_latent, dim=0)
518
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
519
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
520
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
521
+ packed_sequence[packed_vae_token_indexes] = packed_latent
522
+
523
+ extra_inputs = {}
524
+ if self.use_moe:
525
+ extra_inputs = {
526
+ "mode": "gen",
527
+ "packed_vae_token_indexes": packed_vae_token_indexes,
528
+ "packed_text_indexes": packed_text_indexes
529
+ }
530
+
531
+ output = self.language_model.forward_inference(
532
+ packed_query_sequence=packed_sequence,
533
+ query_lens=packed_seqlens,
534
+ packed_query_position_ids=packed_position_ids,
535
+ packed_query_indexes=packed_indexes,
536
+ past_key_values=past_key_values,
537
+ key_values_lens=key_values_lens,
538
+ packed_key_value_indexes=packed_key_value_indexes,
539
+ update_past_key_values=True,
540
+ is_causal=False,
541
+ **extra_inputs,
542
+ )
543
+ past_key_values = output.past_key_values
544
+
545
+ return past_key_values
546
+
547
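# Editor's sketch (illustrative, not part of this commit): encoding the padded images and
# caching the resulting latent tokens. `vae_model` is assumed to be the external
# autoencoder whose encode() is called inside forward_cache_update_vae().
import torch

with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
    past_key_values = model.forward_cache_update_vae(vae_model, past_key_values, **generation_input)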
+ def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
548
+ packed_text_ids, packed_text_indexes = list(), list()
549
+ packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
550
+ packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
551
+ packed_key_value_indexes = list()
552
+
553
+ query_curr = curr = 0
554
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
555
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
556
+ curr += curr_kvlen
557
+
558
+ packed_text_ids.append(new_token_ids['start_of_image'])
559
+ packed_text_indexes.append(query_curr)
560
+ packed_indexes.append(curr)
561
+ curr += 1
562
+ query_curr += 1
563
+
564
+ vae_posiiton_ids = self.get_flattened_position_ids(
565
+ H, W,
566
+ self.latent_downsample,
567
+ max_num_patches_per_side=self.max_latent_size
568
+ )
569
+ packed_vae_position_ids.append(vae_posiiton_ids)
570
+
571
+ h, w = H // self.latent_downsample, W // self.latent_downsample
572
+ num_image_tokens = h * w
573
+ packed_init_noises.append(
574
+ torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
575
+ )
576
+ packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
577
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
578
+ curr += num_image_tokens
579
+ query_curr += num_image_tokens
580
+
581
+ packed_text_ids.append(new_token_ids['end_of_image'])
582
+ packed_text_indexes.append(query_curr)
583
+ packed_indexes.append(curr)
584
+ curr += 1
585
+ query_curr += 1
586
+
587
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
588
+ packed_seqlens.append(num_image_tokens + 2)
589
+
590
+ generation_input = {
591
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
592
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
593
+ "packed_init_noises": torch.cat(packed_init_noises, dim=0),
594
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
595
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
596
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
597
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
598
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
599
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
600
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
601
+ }
602
+
603
+ return generation_input
604
+
605
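# Editor's sketch (illustrative, not part of this commit): allocating the initial noise and
# query bookkeeping for a 1024x1024 generation target (sizes must be divisible by
# latent_downsample).
generation_input = model.prepare_vae_latent(
    curr_kvlens=newlens,
    curr_rope=new_rope,
    image_sizes=[(1024, 1024)],
    new_token_ids=new_token_ids,
)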
+ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
606
+ packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
607
+
608
+ query_curr = curr = 0
609
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
610
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
611
+ curr += curr_kvlen
612
+
613
+ packed_indexes.append(curr)
614
+ curr += 1
615
+ query_curr += 1
616
+
617
+ h, w = H // self.latent_downsample, W // self.latent_downsample
618
+ num_image_tokens = h * w
619
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
620
+ curr += num_image_tokens
621
+ query_curr += num_image_tokens
622
+
623
+ packed_indexes.append(curr)
624
+ curr += 1
625
+ query_curr += 1
626
+
627
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
628
+
629
+ generation_input = {
630
+ "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
631
+ "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
632
+ "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
633
+ "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
634
+ }
635
+
636
+ return generation_input
637
+
638
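# Editor's sketch (illustrative, not part of this commit): building the query/position
# bookkeeping for one CFG branch. `cfg_text_kvlens` / `cfg_text_rope` are assumed to
# describe a separate cache that was prefilled without the text prompt, which is the usual
# setup for text classifier-free guidance.
cfg_text_input = model.prepare_vae_latent_cfg(
    curr_kvlens=cfg_text_kvlens,
    curr_rope=cfg_text_rope,
    image_sizes=[(1024, 1024)],
)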
+ @torch.no_grad
639
+ def generate_image(
640
+ self,
641
+ packed_text_ids: torch.LongTensor,
642
+ packed_text_indexes: torch.LongTensor,
643
+ packed_init_noises: torch.Tensor,
644
+ packed_vae_position_ids: torch.LongTensor,
645
+ packed_vae_token_indexes: torch.LongTensor,
646
+ packed_seqlens: torch.IntTensor,
647
+ packed_position_ids: torch.LongTensor,
648
+ packed_indexes: torch.LongTensor,
649
+ past_key_values: NaiveCache,
650
+ key_values_lens: torch.IntTensor,
651
+ packed_key_value_indexes: torch.LongTensor,
652
+ num_timesteps: int = 24,
653
+ timestep_shift: float = 1.0,
654
+ cfg_renorm_min: float = 0.0,
655
+ cfg_renorm_type: str = "global",
656
+ cfg_interval: Optional[Tuple[float, float]] = (0.0, 1.0),
657
+ # cfg_text
658
+ cfg_text_scale: float = 1.0,
659
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
660
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
661
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
662
+ cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
663
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
664
+ # cfg_img
665
+ cfg_img_scale: float = 1.0,
666
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
667
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
668
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
669
+ cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
670
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
671
+ cfg_type: str = "parallel",
672
+ ):
673
+ x_t = packed_init_noises
674
+
675
+ timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
676
+ timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
677
+ dts = timesteps[:-1] - timesteps[1:]
678
+ timesteps = timesteps[:-1]
679
+
680
+ for i, t in enumerate(timesteps):
681
+
682
+ timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
683
+ if t > cfg_interval[0] and t <= cfg_interval[1]:
684
+ cfg_text_scale_ = cfg_text_scale
685
+ cfg_img_scale_ = cfg_img_scale
686
+ else:
687
+ cfg_text_scale_ = 1.0
688
+ cfg_img_scale_ = 1.0
689
+ v_t = self._forward_flow(
690
+ x_t=x_t,
691
+ timestep=timestep,
692
+ packed_vae_token_indexes=packed_vae_token_indexes,
693
+ packed_vae_position_ids=packed_vae_position_ids,
694
+ packed_text_ids=packed_text_ids,
695
+ packed_text_indexes=packed_text_indexes,
696
+ packed_position_ids=packed_position_ids,
697
+ packed_indexes=packed_indexes,
698
+ packed_seqlens=packed_seqlens,
699
+ key_values_lens=key_values_lens,
700
+ past_key_values=past_key_values,
701
+ packed_key_value_indexes=packed_key_value_indexes,
702
+ cfg_renorm_min=cfg_renorm_min,
703
+ cfg_renorm_type=cfg_renorm_type,
704
+ # cfg_text
705
+ cfg_text_scale=cfg_text_scale_,
706
+ cfg_text_packed_position_ids=cfg_text_packed_position_ids,
707
+ cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
708
+ cfg_text_key_values_lens=cfg_text_key_values_lens,
709
+ cfg_text_past_key_values=cfg_text_past_key_values,
710
+ cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
711
+ # cfg_img
712
+ cfg_img_scale=cfg_img_scale_,
713
+ cfg_img_packed_position_ids=cfg_img_packed_position_ids,
714
+ cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
715
+ cfg_img_key_values_lens=cfg_img_key_values_lens,
716
+ cfg_img_past_key_values=cfg_img_past_key_values,
717
+ cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
718
+ cfg_type=cfg_type,
719
+ )
720
+
721
+ x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
722
+
723
+ unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
724
+ return unpacked_latent
725
+
726
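# Editor's sketch (illustrative, not part of this commit): turning one latent returned by
# generate_image() back into pixels by inverting the patchify/einsum layout used in
# forward_cache_update_vae(). `vae_model` is assumed to expose decode(); H and W are the
# target image size passed to prepare_vae_latent().
import torch

latent = unpacked_latent[0]
h, w = H // model.latent_downsample, W // model.latent_downsample
p, c = model.latent_patch_size, model.latent_channel
latent = latent.reshape(h, w, p, p, c)
latent = torch.einsum("hwpqc->chpwq", latent).reshape(1, c, h * p, w * p)
image = vae_model.decode(latent)  # (1, 3, H, W), still in the VAE's output value range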
+ @torch.no_grad
727
+ def _forward_flow(
728
+ self,
729
+ x_t: torch.Tensor,
730
+ timestep: torch.LongTensor,
731
+ packed_vae_token_indexes: torch.LongTensor,
732
+ packed_vae_position_ids: torch.LongTensor,
733
+ packed_text_ids: torch.LongTensor,
734
+ packed_text_indexes: torch.LongTensor,
735
+ packed_indexes: torch.LongTensor,
736
+ packed_position_ids: torch.LongTensor,
737
+ packed_seqlens: torch.IntTensor,
738
+ key_values_lens: torch.IntTensor,
739
+ past_key_values: NaiveCache,
740
+ packed_key_value_indexes: torch.LongTensor,
741
+ cfg_renorm_min: float = 0.0,
742
+ cfg_renorm_type: str = "global",
743
+ # cfg_text
744
+ cfg_text_scale: float = 1.0,
745
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
746
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
747
+ cfg_text_key_values_lens: Optional[torch.Tensor] = None,
748
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
749
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
750
+ # cfg_img
751
+ cfg_img_scale: float = 1.0,
752
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
753
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
754
+ cfg_img_key_values_lens: Optional[torch.Tensor] = None,
755
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
756
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
757
+ cfg_type: str = "parallel",
758
+ ):
759
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
760
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
761
+ packed_sequence[packed_text_indexes] = packed_text_embedding
762
+
763
+ assert timestep.unique().shape[0] == 1
764
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
765
+ packed_timestep_embeds = self.time_embedder(timestep)
766
+ x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
767
+ packed_sequence[packed_vae_token_indexes] = x_t
768
+
769
+ extra_inputs = {}
770
+ if self.use_moe:
771
+ extra_inputs = {
772
+ "mode": "gen",
773
+ "packed_vae_token_indexes": packed_vae_token_indexes,
774
+ "packed_text_indexes": packed_text_indexes
775
+ }
776
+
777
+ output = self.language_model.forward_inference(
778
+ packed_query_sequence=packed_sequence,
779
+ query_lens=packed_seqlens,
780
+ packed_query_position_ids=packed_position_ids,
781
+ packed_query_indexes=packed_indexes,
782
+ past_key_values=past_key_values,
783
+ key_values_lens=key_values_lens,
784
+ packed_key_value_indexes=packed_key_value_indexes,
785
+ update_past_key_values=False,
786
+ is_causal=False,
787
+ **extra_inputs,
788
+ )
789
+ v_t = self.llm2vae(output.packed_query_sequence)
790
+ v_t = v_t[packed_vae_token_indexes]
791
+
792
+ if cfg_text_scale > 1.0:
793
+ cfg_text_output = self.language_model.forward_inference(
794
+ packed_query_sequence=packed_sequence,
795
+ query_lens=packed_seqlens,
796
+ packed_query_position_ids=cfg_text_packed_position_ids,
797
+ packed_query_indexes=cfg_text_packed_query_indexes,
798
+ past_key_values=cfg_text_past_key_values,
799
+ key_values_lens=cfg_text_key_values_lens,
800
+ packed_key_value_indexes=cfg_text_packed_key_value_indexes,
801
+ update_past_key_values=False,
802
+ is_causal=False,
803
+ **extra_inputs,
804
+ )
805
+ cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
806
+ cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
807
+
808
+ if cfg_img_scale > 1.0:
809
+ cfg_img_output = self.language_model.forward_inference(
810
+ packed_query_sequence=packed_sequence,
811
+ query_lens=packed_seqlens,
812
+ packed_query_position_ids=cfg_img_packed_position_ids,
813
+ packed_query_indexes=cfg_img_packed_query_indexes,
814
+ past_key_values=cfg_img_past_key_values,
815
+ key_values_lens=cfg_img_key_values_lens,
816
+ packed_key_value_indexes=cfg_img_packed_key_value_indexes,
817
+ update_past_key_values=False,
818
+ is_causal=False,
819
+ **extra_inputs,
820
+ )
821
+ cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
822
+ cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
823
+
824
+ if cfg_text_scale > 1.0:
825
+ if cfg_renorm_type == "text_channel":
826
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
827
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
828
+ norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
829
+ scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
830
+ v_t_text = v_t_text_ * scale
831
+ if cfg_img_scale > 1.0:
832
+ v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
833
+ else:
834
+ v_t = v_t_text
835
+ else:
836
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
837
+
838
+ if cfg_img_scale > 1.0:
839
+ v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
840
+ else:
841
+ v_t_ = v_t_text_
842
+
843
+ # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
844
+ if cfg_renorm_type == "global":
845
+ norm_v_t = torch.norm(v_t)
846
+ norm_v_t_ = torch.norm(v_t_)
847
+ elif cfg_renorm_type == "channel":
848
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
849
+ norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
850
+ else:
851
+ raise NotImplementedError(f"{cfg_renorm_type} is not supported")
852
+ scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
853
+ v_t = v_t_ * scale
854
+ else:
855
+ # No CFG
856
+ pass
857
+
858
+ return v_t
859
+
860
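# Editor's sketch (illustrative, not part of this commit): the "global" CFG renorm branch
# of _forward_flow() in isolation. The guided velocity is rescaled so its norm never
# exceeds that of the conditional velocity, with cfg_renorm_min as a lower bound on the
# scale factor.
import torch

def cfg_renorm_global(v_cond, v_guided, cfg_renorm_min=0.0):
    scale = torch.norm(v_cond) / (torch.norm(v_guided) + 1e-8)
    return v_guided * scale.clamp(min=cfg_renorm_min, max=1.0)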
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
861
+ packed_start_tokens, packed_key_value_indexes = list(), list()
862
+ packed_query_position_ids = list()
863
+
864
+ curr = 0
865
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
866
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
867
+ packed_start_tokens.append(new_token_ids['bos_token_id'])
868
+ packed_query_position_ids.append(curr_position_id)
869
+ curr += curr_kvlen
870
+
871
+ generation_input = {
872
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
873
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
874
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
875
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
876
+ }
877
+
878
+ return generation_input
879
+
880
+ @torch.no_grad
881
+ def generate_text(
882
+ self,
883
+ past_key_values: NaiveCache,
884
+ packed_key_value_indexes: torch.LongTensor,
885
+ key_values_lens: torch.IntTensor,
886
+ packed_start_tokens: torch.LongTensor,
887
+ packed_query_position_ids: torch.LongTensor,
888
+ max_length: int,
889
+ do_sample: bool = False,
890
+ temperature: float = 1.0,
891
+ end_token_id: int = None,
892
+ ):
893
+ """
894
+ Generates text token by token in a streaming fashion.
895
+
896
+ This function is a generator that yields one token at a time. It replicates
897
+ the behavior of the original batch generation function, including the handling
898
+ of start tokens and the end-of-sequence token.
899
+ """
900
+ curr_tokens = packed_start_tokens
901
+
902
+ for _ in range(max_length):
903
+ # The original function would append `curr_tokens` to a list at this point.
904
+ # Instead, we yield it to the caller, enabling streaming.
905
+ yield curr_tokens
906
+
907
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
908
+ query_lens = torch.ones_like(curr_tokens)
909
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
910
+ 0, len(key_values_lens),
911
+ device=key_values_lens.device,
912
+ dtype=key_values_lens.dtype
913
+ )
914
+
915
+ # This block modifies packed_key_value_indexes before the forward pass,
916
+ # preserving the specific logic for NaViT-style packed inputs.
917
+ # The typo 'uppacked' is kept to match the original source code.
918
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
919
+ for i in range(len(uppacked)):
920
+ uppacked[i] += i
921
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
922
+
923
+ extra_inputs = {}
924
+ if self.use_moe:
925
+ extra_inputs = {"mode": "und"}
926
+
927
+ output = self.language_model.forward_inference(
928
+ packed_query_sequence=packed_text_embedding,
929
+ query_lens=query_lens,
930
+ packed_query_position_ids=packed_query_position_ids,
931
+ packed_query_indexes=packed_query_indexes,
932
+ past_key_values=past_key_values,
933
+ key_values_lens=key_values_lens,
934
+ packed_key_value_indexes=packed_key_value_indexes,
935
+ update_past_key_values=True,
936
+ is_causal=True,
937
+ **extra_inputs,
938
+ )
939
+ past_key_values = output.past_key_values
940
+ packed_query_sequence = output.packed_query_sequence
941
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
942
+
943
+ # Sample the next token
944
+ if do_sample:
945
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
946
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
947
+ else:
948
+ next_tokens = torch.argmax(pred_logits, dim=-1)
949
+
950
+ # The stop condition is checked on the newly generated token. If it's the
951
+ # end token, we break the loop. This token will not be yielded.
952
+ if end_token_id is not None and next_tokens[0] == end_token_id: # only support batch=1
953
+ break
954
+
955
+ # This block updates the state variables for the next iteration. It reads
956
+ # the already-modified `packed_key_value_indexes` and updates it further.
957
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
958
+ for i in range(len(uppacked)):
959
+ uppacked[i] = torch.cat(
960
+ [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
961
+ )
962
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
963
+ key_values_lens = key_values_lens + 1
964
+ packed_query_position_ids = packed_query_position_ids + 1
965
+
966
+ # The newly generated token becomes the input for the next loop iteration.
967
+ curr_tokens = next_tokens
968
+
969
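# Editor's sketch (illustrative, not part of this commit): consuming the streaming
# generator above. `generation_input` is assumed to come from prepare_start_tokens();
# note that the first yielded tensor is the start token itself.
import torch

generated = []
for token in model.generate_text(
    past_key_values=past_key_values,
    max_length=64,
    do_sample=False,
    end_token_id=new_token_ids['eos_token_id'],
    **generation_input,
):
    generated.append(token)
text = tokenizer.decode(torch.cat(generated))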
+ # for evaluation
970
+ @torch.no_grad()
971
+ def chat(
972
+ self,
973
+ tokenizer,
974
+ new_token_ids,
975
+ image_transform,
976
+ images,
977
+ prompt,
978
+ max_length: int,
979
+ do_sample: bool = False,
980
+ temperature: float = 1.0,
981
+ ):
982
+ device = next(self.parameters()).device
983
+
984
+ if isinstance(new_token_ids, dict):
985
+ for k, v in new_token_ids.items():
986
+ if torch.is_tensor(v):
987
+ new_token_ids[k] = v.to(device)
988
+ elif torch.is_tensor(new_token_ids):
989
+ new_token_ids = new_token_ids.to(device)
990
+
991
+ # prefill
992
+ past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
993
+ newlens = [0]
994
+ new_rope = [0]
995
+
996
+ # add images
997
+ for image in images:
998
+ generation_input, newlens, new_rope = self.prepare_vit_images(
999
+ curr_kvlens=newlens,
1000
+ curr_rope=new_rope,
1001
+ images=[image],
1002
+ transforms=image_transform,
1003
+ new_token_ids=new_token_ids,
1004
+ )
1005
+ for k, v in generation_input.items():
1006
+ if torch.is_tensor(v):
1007
+ generation_input[k] = v.to(device)
1008
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1009
+ past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
1010
+
1011
+ # add text
1012
+ generation_input, newlens, new_rope = self.prepare_prompts(
1013
+ curr_kvlens=newlens,
1014
+ curr_rope=new_rope,
1015
+ prompts=[prompt],
1016
+ tokenizer=tokenizer,
1017
+ new_token_ids=new_token_ids,
1018
+ )
1019
+ for k, v in generation_input.items():
1020
+ if torch.is_tensor(v):
1021
+ generation_input[k] = v.to(device)
1022
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1023
+ past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1024
+
1025
+ # decode
1026
+ generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1027
+ for k, v in generation_input.items():
1028
+ if torch.is_tensor(v):
1029
+ generation_input[k] = v.to(device)
1030
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1031
+ for curr_tokens in self.generate_text(
1032
+ past_key_values=past_key_values,
1033
+ max_length=max_length,
1034
+ do_sample=do_sample,
1035
+ temperature=temperature,
1036
+ end_token_id=new_token_ids['eos_token_id'],
1037
+ **generation_input,
1038
+ ):
1039
+ output = tokenizer.decode(curr_tokens)
1040
  yield output
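# Editor's sketch (illustrative, not part of this commit): driving the streaming chat()
# API end to end; each yielded chunk is one decoded token. `model`, `tokenizer`,
# `new_token_ids` and `image_transform` are assumed to come from the repo's loading
# utilities, and "example.jpg" is a hypothetical path.
from PIL import Image

for chunk in model.chat(
    tokenizer=tokenizer,
    new_token_ids=new_token_ids,
    image_transform=image_transform,
    images=[Image.open("example.jpg").convert("RGB")],
    prompt="What is shown in this picture?",
    max_length=128,
):
    print(chunk, end="", flush=True)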