|
|
from typing import List, Optional, Union |
|
|
|
|
|
import torch |
|
|
from transformers import T5EncoderModel, T5Tokenizer |
|
|
|
|
|
|
|
|
def _get_t5_prompt_embeds( |
|
|
tokenizer: T5Tokenizer, |
|
|
text_encoder: T5EncoderModel, |
|
|
prompt: Union[str, List[str]], |
|
|
num_videos_per_prompt: int = 1, |
|
|
max_sequence_length: int = 226, |
|
|
device: Optional[torch.device] = None, |
|
|
dtype: Optional[torch.dtype] = None, |
|
|
text_input_ids=None, |
|
|
): |
|
|
prompt = [prompt] if isinstance(prompt, str) else prompt |
|
|
batch_size = len(prompt) |
|
|
|
|
|
if tokenizer is not None: |
|
|
text_inputs = tokenizer( |
|
|
prompt, |
|
|
padding="max_length", |
|
|
max_length=max_sequence_length, |
|
|
truncation=True, |
|
|
add_special_tokens=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
text_input_ids = text_inputs.input_ids |
|
|
else: |
|
|
if text_input_ids is None: |
|
|
raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") |
|
|
|
|
|
prompt_embeds = text_encoder(text_input_ids.to(device))[0] |
|
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) |
|
|
|
|
|
|
|
|
_, seq_len, _ = prompt_embeds.shape |
|
|
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) |
|
|
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) |
|
|
|
|
|
return prompt_embeds |
|
|
|
|
|
|
|
|
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Return prompt embeddings by delegating to ``_get_t5_prompt_embeds``.

    A bare string ``prompt`` is wrapped into a one-element list first; every
    other argument is forwarded unchanged.

    Args:
        tokenizer: Tokenizer forwarded to the helper.
        text_encoder: Text encoder forwarded to the helper.
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Forwarded to the helper.
        max_sequence_length: Forwarded to the helper.
        device: Forwarded to the helper.
        dtype: Forwarded to the helper.
        text_input_ids: Forwarded to the helper.

    Returns:
        Whatever ``_get_t5_prompt_embeds`` returns for these arguments.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    return _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
|
|
|
|
|
|
|
|
def compute_prompt_embeddings(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: str,
    max_sequence_length: int,
    device: torch.device,
    dtype: torch.dtype,
    requires_grad: bool = False,
):
    """Encode ``prompt`` via ``encode_prompt``, optionally without gradients.

    Args:
        tokenizer: Tokenizer forwarded to ``encode_prompt``.
        text_encoder: Text encoder forwarded to ``encode_prompt``.
        prompt: Prompt forwarded to ``encode_prompt``.
        max_sequence_length: Forwarded to ``encode_prompt``.
        device: Forwarded to ``encode_prompt``.
        dtype: Forwarded to ``encode_prompt``.
        requires_grad: When ``False`` (default) the encoding runs inside
            ``torch.no_grad()``. When ``True`` the call is made directly, in
            whatever grad mode the caller is already in.

    Returns:
        The embeddings returned by ``encode_prompt`` with
        ``num_videos_per_prompt=1``.
    """
    # Both paths make the same call; only the surrounding grad mode differs.
    encode_kwargs = {
        "num_videos_per_prompt": 1,
        "max_sequence_length": max_sequence_length,
        "device": device,
        "dtype": dtype,
    }

    if not requires_grad:
        with torch.no_grad():
            return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)

    return encode_prompt(tokenizer, text_encoder, prompt, **encode_kwargs)
|
|
|