Commit 212945c · Update llama/m2ugen.py
Parent: c4f1082

llama/m2ugen.py CHANGED (+70, -25)
@@ -231,9 +231,9 @@ class M2UGen(nn.Module):
         self.music_decoder = self.args.music_decoder.lower()
 
         # 4. prefix
-        self.query_layer =
+        self.query_layer = 6
         self.query_len = 1
-        self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:0")
+        self.prefix_query = nn.Embedding(self.query_layer * 3 * self.query_len, self.model_args.dim).to("cuda:0")
 
         # 5. knn
         self.knn = knn
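Note on the first hunk: the prefix table is tripled so that each of the three modalities (audio, image, video) gets its own bank of per-layer query vectors; the truncated left side of the rendered diff does not show the old query_layer value. With query_layer = 6 and query_len = 1, the nn.Embedding holds 6 * 3 * 1 = 18 rows of width model_args.dim. A minimal shape check, assuming dim = 4096 (the width hard-coded in the reshape calls below):

    import torch.nn as nn

    query_layer, query_len, dim = 6, 1, 4096  # values from this commit
    prefix_query = nn.Embedding(query_layer * 3 * query_len, dim)  # 18 x 4096
    # The reshape used in both forward passes yields one (1, 1, dim) query
    # per adapted layer, indexed by prefix_index.
    q = prefix_query.weight.reshape(query_layer * 3, 1, dim).unsqueeze(1)
    assert q.shape == (18, 1, 1, 4096)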
@@ -492,30 +492,52 @@ class M2UGen(nn.Module):
         h = self.llama.tok_embeddings(tokens).to("cuda:0")
         freqs_cis = self.llama.freqs_cis.to("cuda:0")
         freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
-
-        feats = torch.zeros((1, 1, 4096)).to("cuda:0")
-        if audio_feats is not None:
-            feats += audio_feats
-        if video_feats is not None:
-            feats += video_feats
-        if image_feats is not None:
-            feats += image_feats
 
         mask = None
         mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
         mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
 
         music_output_embedding = []
-        for layer in self.llama.layers[:-
+        for layer in self.llama.layers[:-3 * self.query_layer]:
             h = layer(h, 0, freqs_cis, mask)
             music_output_embedding.append(h)
 
-        prefix_query = self.prefix_query.weight.reshape(
+        prefix_query = self.prefix_query.weight.reshape(
+            self.query_layer * 3, 1, 4096).unsqueeze(1)
 
         prefix_index = 0
-
-
-
+        if audio_feats is not None:
+            for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
+
+        if image_feats is not None:
+            for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
+
+        if video_feats is not None:
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                music_output_embedding.append(h)
+                prefix_index = prefix_index + 1
 
         h = self.llama.norm(h)
         output = self.llama.output(h[:, -1, :])
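Note on the second hunk: the old code collapsed all available modality features into a single feats tensor; the new code instead partitions the last 3 * query_layer transformer blocks into three consecutive bands, and each band injects one modality's features on top of that layer's prefix query (or the bare prefix query when the modality is absent). An index-only sketch of the partitioning, assuming a 32-layer LLaMA (the layer count is not shown in this diff):

    query_layer = 6
    layers = list(range(32))  # stand-in for self.llama.layers

    trunk      = layers[:-3 * query_layer]                  # 0..13: no prefix
    audio_band = layers[-3 * query_layer:-2 * query_layer]  # 14..19
    image_band = layers[-2 * query_layer:-1 * query_layer]  # 20..25
    video_band = layers[-1 * query_layer:]                  # 26..31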
@@ -523,30 +545,53 @@ class M2UGen(nn.Module):
         return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
 
     def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
-
+        audio_feats, video_feats, image_feats = None, None, None
         if audios is not None:
-
+            audio_feats = self.forward_audio({'Audio': [audios, 1]})
         if videos is not None:
-
+            video_feats = self.forward_video({'Video': [videos, 1]})
         if imgs is not None:
-
+            image_feats = self.forward_image({'Image': [imgs, 1]})
         _bsz, seqlen = tokens.shape
 
         h = self.llama.tok_embeddings(tokens.to(self.device))
         freqs_cis = self.llama.freqs_cis.to(h.device)
         freqs_cis = freqs_cis[:seqlen]
-        mask = None
         mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
         mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
 
-        for layer in self.llama.layers[:-
+        for layer in self.llama.layers[:-3 * self.query_layer]:
             h = layer(h, 0, freqs_cis, mask)
-        prefix_query = self.prefix_query.weight.reshape(
+        prefix_query = self.prefix_query.weight.reshape(
+            self.query_layer * 3, 1, 4096).unsqueeze(1)
+
         prefix_index = 0
-
-
-
+        if audio_feats is not None:
+            for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
 
+        if image_feats is not None:
+            for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
+
+        if video_feats is not None:
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
+        else:
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
+                prefix_index = prefix_index + 1
 
         final_hidden = h
         h = self.llama.norm(h)
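Note on the third hunk: forward() now computes its own modality features via forward_audio / forward_video / forward_image rather than receiving them pre-summed, then applies the same three-band routine as the generation path, minus the music_output_embedding bookkeeping. The six near-identical loops could be collapsed into one helper; a hypothetical sketch (run_band is not part of this commit):

    def run_band(h, band, feats, prefix_query, prefix_index, freqs_cis, mask):
        # Run one modality band: each layer receives its own prefix query,
        # with the modality features added on top when they are present.
        for layer in band:
            extra = prefix_query[prefix_index] if feats is None \
                else feats + prefix_query[prefix_index]
            h = layer(h, 0, freqs_cis, mask, extra)
            prefix_index += 1
        return h, prefix_index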