# modeling_emuru.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
from configuration_emuru import EmuruConfig
# from .configuration_emuru import EmuruConfig
from diffusers import AutoencoderKL
from einops.layers.torch import Rearrange
from einops import rearrange, repeat


class Emuru(PreTrainedModel):
    config_class = EmuruConfig  # configuration class used by from_pretrained / save_pretrained

    def __init__(self, config):
        super().__init__(config)

        # Tokenizer (kept as part of the model)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)

        # Build T5 from the referenced config and resize its vocabulary to the tokenizer.
        # The LM head is replaced with an identity so the decoder returns raw d_model features.
        t5_config = T5Config.from_pretrained(config.t5_config)
        t5_config.vocab_size = len(self.tokenizer)
        self.T5 = T5ForConditionalGeneration(t5_config)
        self.T5.lm_head = nn.Identity()
        self.sos = nn.Embedding(1, t5_config.d_model)

        # Each query packs `slices_per_query` latent columns (8 corresponds to the VAE latent height)
        vae_latent_size = 8 * config.vae_channels * config.slices_per_query
        self.vae_to_t5 = nn.Linear(vae_latent_size, t5_config.d_model)
        self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)

        # Frozen parameters for the padding token and the stopping threshold
        # (see the compute_padding_token* stubs below)
        self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
        self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)

        # Load the VAE and freeze it
        self.vae = AutoencoderKL.from_pretrained(config.vae_config)
        self.set_training(self.vae, False)

        # Rearrange layers mapping between the VAE latent grid and the T5 sequence
        self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
        self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)

        # Loss functions
        self.mse_criterion = nn.MSELoss()

        # Initialize weights following Hugging Face conventions
        self.init_weights()

    def set_training(self, model, training):
        # Toggle train/eval mode and gradient tracking for a sub-module
        model.train() if training else model.eval()
        for param in model.parameters():
            param.requires_grad = training

    def forward(self, img=None, input_ids=None, attention_mask=None, noise=0, **kwargs):
        # Teacher-forced training pass: encode the image into a latent slice sequence,
        # feed it (prefixed with the SOS embedding) to the T5 decoder conditioned on the
        # text, and regress each next latent slice with an MSE loss.
        decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
        output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
        vae_latent = self.t5_to_vae(output.logits[:, :-1])
        pred_latent = self.z_rearrange(vae_latent)

        mse_loss = self.mse_criterion(vae_latent, z_sequence)
        return mse_loss, pred_latent, z
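
    # A minimal training-step sketch (illustrative only; the config path, batch fields,
    # and optimizer below are assumptions, not part of this file):
    #
    #   model = Emuru(EmuruConfig.from_pretrained("path/to/config"))   # hypothetical path
    #   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #   for batch in dataloader:  # assumed to yield images and tokenized text
    #       mse_loss, pred_latent, z = model(img=batch["img"],
    #                                        input_ids=batch["input_ids"],
    #                                        attention_mask=batch["attention_mask"],
    #                                        noise=0.1)
    #       optimizer.zero_grad()
    #       mse_loss.backward()
    #       optimizer.step()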

    def old_generate(self, text=None, img=None, z_sequence=None, input_ids=None, max_new_tokens=256,
                     stopping_criteria='latent', stopping_after=10, stopping_errors=1):
        # Earlier version of the autoregressive sampling loop (see generate() below)
        assert text is not None or input_ids is not None, 'Either text or input_ids must be provided'
        assert img is not None or z_sequence is not None, 'Either img or z_sequence must be provided'

        if input_ids is None:
            input_ids = self.tokenizer(text, return_tensors='pt', padding=True).input_ids
            input_ids = input_ids.to(next(self.T5.parameters()).device)

        if z_sequence is None:
            _, z_sequence, _ = self._img_encode(img)
        z_sequence = [z_sequence]

        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
        for _ in range(max_new_tokens):
            if len(z_sequence) == 0:
                decoder_inputs_embeds = sos
            else:
                decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
                decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
            output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
            vae_latent = self.t5_to_vae(output.logits[:, -1:])
            z_sequence.append(vae_latent)

            if stopping_criteria == 'latent':
                # Stop once the last `stopping_after` predicted slices are close enough to
                # the padding token, allowing up to `stopping_errors` misses
                curr_z_sequence = torch.cat(z_sequence, dim=1)
                pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0)).to(decoder_inputs_embeds.device)
                similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
                similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
                if torch.all(similarity.sum(-1) >= (stopping_after - stopping_errors)):
                    # z_sequence = [curr_z_sequence[:, :-stopping_after]]
                    z_sequence = [curr_z_sequence]
                    break
            elif stopping_criteria == 'pixel':
                raise NotImplementedError

        z_sequence = torch.cat(z_sequence, dim=1)
        img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
        return img

    def generate(self,
                 style_text=None,
                 gen_text=None,
                 style_img=None,
                 input_ids=None,
                 z_sequence=None,
                 max_new_tokens=256,
                 stopping_criteria='latent',
                 stopping_after=10,
                 stopping_patience=1,
                 trim_image=True):
        assert (gen_text is not None and style_text is not None) or input_ids is not None, 'Either gen_text and style_text or input_ids must be provided'
        assert style_img is not None or z_sequence is not None, 'Either style_img or z_sequence must be provided'

        if input_ids is None:
            input_ids = self.tokenizer(gen_text + ' ' + style_text, return_tensors='pt', padding=True).input_ids
            input_ids = input_ids.to(self.device)

        if z_sequence is None:
            _, z_sequence, _ = self._img_encode(style_img)
        z_sequence = [z_sequence]

        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
        pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))

        for _ in range(max_new_tokens):
            if len(z_sequence) == 0:
                decoder_inputs_embeds = sos
            else:
                decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
                decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
            output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
            vae_latent = self.t5_to_vae(output.logits[:, -1:])
            z_sequence.append(vae_latent)

            if stopping_criteria == 'latent':
                # Stop once the last `stopping_after` predicted slices match the padding token
                # (up to `stopping_patience` misses); optionally trim them from the output
                curr_z_sequence = torch.cat(z_sequence, dim=1)
                similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
                similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
                if torch.all(similarity.sum(-1) >= (stopping_after - stopping_patience)):
                    # NOTE: the slice below converts a per-sample count into a single Python
                    # index, so trimming effectively assumes a batch size of 1
                    z_sequence = [curr_z_sequence[:, :-similarity.sum(-1)]] if trim_image else [curr_z_sequence]
                    break
            elif stopping_criteria == 'pixel':
                raise NotImplementedError

        z_sequence = torch.cat(z_sequence, dim=1)
        img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
        return img, z_sequence

    def _img_encode(self, img, noise=0):
        # Encode the image with the frozen VAE and flatten the latent grid into a sequence
        # of vertical slices; optionally add Gaussian noise before projecting to the T5
        # embedding space and prepending the SOS embedding
        posterior = self.vae.encode(img.float())
        z = posterior.latent_dist.sample()
        z_sequence = self.query_rearrange(z)

        noise_sequence = z_sequence
        if noise > 0:
            noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise

        decoder_inputs_embeds = self.vae_to_t5(noise_sequence)
        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=decoder_inputs_embeds.size(0))
        decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
        return decoder_inputs_embeds, z_sequence, z

    def compute_padding_token(self):
        raise NotImplementedError("compute_padding_token not implemented")

    def compute_padding_token_threshold(self):
        raise NotImplementedError("compute_padding_token_threshold not implemented")