Vittorio Pippi committed on
Commit 5f68c1b · 1 Parent(s): 19d8873

Rename configuration parameters in EmuruConfig for clarity

Files changed (3)
  1. configuration_emuru.py +6 -6
  2. modeling_emuru.bkp.py +0 -316
  3. modeling_emuru.py +235 -101
configuration_emuru.py CHANGED
@@ -4,15 +4,15 @@ class EmuruConfig(PretrainedConfig):
     model_type = "emuru"
 
     def __init__(self,
-                 t5_config='google-t5/t5-large',
-                 vae_config='blowing-up-groundhogs/emuru_vae',
-                 tokenizer_config='google/byt5-small',
+                 t5_name_or_path='google-t5/t5-large',
+                 vae_name_or_path='blowing-up-groundhogs/emuru_vae',
+                 tokenizer_name_or_path='google/byt5-small',
                  slices_per_query=1,
                  vae_channels=1,
                  **kwargs):
         super().__init__(**kwargs)
-        self.t5_config = t5_config
-        self.vae_config = vae_config
-        self.tokenizer_config = tokenizer_config
+        self.t5_name_or_path = t5_name_or_path
+        self.vae_name_or_path = vae_name_or_path
+        self.tokenizer_name_or_path = tokenizer_name_or_path
         self.slices_per_query = slices_per_query
         self.vae_channels = vae_channels
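
With this rename, downstream code has to construct the config through the new *_name_or_path arguments; the old t5_config / vae_config / tokenizer_config attributes no longer exist. A minimal sketch of the updated usage, assuming the repository-local configuration_emuru module is importable:

    from configuration_emuru import EmuruConfig

    # The defaults below are the ones shown in the diff above.
    config = EmuruConfig(
        t5_name_or_path='google-t5/t5-large',
        vae_name_or_path='blowing-up-groundhogs/emuru_vae',
        tokenizer_name_or_path='google/byt5-small',
        slices_per_query=1,
        vae_channels=1,
    )
    print(config.t5_name_or_path)   # 'google-t5/t5-large'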
modeling_emuru.bkp.py DELETED
@@ -1,316 +0,0 @@
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
-from configuration_emuru import EmuruConfig
-from diffusers import AutoencoderKL
-from einops.layers.torch import Rearrange
-from einops import repeat
-from torchvision.transforms import functional as F
-from typing import Optional, Tuple, List, Any
-from PIL import Image
-
-class Emuru(PreTrainedModel):
-    config_class = EmuruConfig
-    """
-    Emuru is a conditional generative model that integrates a T5-based decoder with a VAE
-    for image generation conditioned on text and style images.
-
-    Attributes:
-        config_class (Type): Configuration class for the model.
-        tokenizer (AutoTokenizer): Tokenizer loaded from the provided tokenizer configuration.
-        T5 (T5ForConditionalGeneration): T5 model adapted for conditional generation.
-        sos (nn.Embedding): Start-of-sequence embedding.
-        vae_to_t5 (nn.Linear): Linear projection from VAE latent space to T5 hidden space.
-        t5_to_vae (nn.Linear): Linear projection from T5 hidden space back to VAE latent space.
-        padding_token (nn.Parameter): Non-trainable parameter for padding tokens.
-        padding_token_threshold (nn.Parameter): Non-trainable parameter for padding token threshold.
-        vae (AutoencoderKL): Pre-trained Variational Autoencoder.
-        query_rearrange (Rearrange): Layer to rearrange VAE latent representations for queries.
-        z_rearrange (Rearrange): Layer to rearrange T5 outputs back to VAE latent dimensions.
-        mse_criterion (nn.MSELoss): Mean squared error loss function.
-    """
-
-    def __init__(self, config: EmuruConfig) -> None:
-        """
-        Initialize the Emuru model.
-
-        Args:
-            config (EmuruConfig): Configuration object containing model hyperparameters and paths.
-        """
-        super().__init__(config)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)
-
-        t5_config = T5Config.from_pretrained(config.t5_config)
-        t5_config.vocab_size = len(self.tokenizer)
-        self.T5 = T5ForConditionalGeneration(t5_config)
-        self.T5.lm_head = nn.Identity()
-        self.sos = nn.Embedding(1, t5_config.d_model)
-
-        vae_latent_size = 8 * config.vae_channels * config.slices_per_query
-        self.vae_to_t5 = nn.Linear(vae_latent_size, t5_config.d_model)
-        self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)
-
-        self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
-        self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)
-
-        self.vae = AutoencoderKL.from_pretrained(config.vae_config)
-        self.set_training(self.vae, False)
-
-        self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
-        self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
-
-        self.mse_criterion = nn.MSELoss()
-        self.init_weights()
-
-    def set_training(self, model: nn.Module, training: bool) -> None:
-        """
-        Set the training mode for a given model and freeze/unfreeze parameters accordingly.
-
-        Args:
-            model (nn.Module): The model to set the training mode for.
-            training (bool): If True, set the model to training mode; otherwise, evaluation mode.
-        """
-        model.train() if training else model.eval()
-        for param in model.parameters():
-            param.requires_grad = training
-
-    def forward(
-        self,
-        img: Optional[torch.Tensor] = None,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        noise: float = 0,
-        **kwargs: Any
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Forward pass of the model.
-
-        Args:
-            img (Optional[torch.Tensor]): Input image tensor.
-            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
-            attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
-            noise (float): Amount of noise to add in image encoding.
-            **kwargs: Additional arguments.
-
-        Returns:
-            Tuple containing:
-                - mse_loss (torch.Tensor): Mean squared error loss.
-                - pred_latent (torch.Tensor): Predicted latent representations.
-                - z (torch.Tensor): Sampled latent vector from VAE.
-        """
-        decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
-
-        output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
-        vae_latent = self.t5_to_vae(output.logits[:, :-1])
-        pred_latent = self.z_rearrange(vae_latent)
-
-        mse_loss = self.mse_criterion(vae_latent, z_sequence)
-        return mse_loss, pred_latent, z
-
-    def generate(
-        self,
-        style_text: str,
-        gen_text: str,
-        style_img: torch.Tensor,
-        **kwargs: Any
-    ) -> Image.Image:
-        """
-        Generate an image by combining style and generation texts with a style image.
-
-        Args:
-            style_text (str): Style-related text prompt.
-            gen_text (str): Generation-related text prompt.
-            style_img (torch.Tensor): Style image tensor. Expected shape is either 3D or 4D.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            Image.Image: Generated image as a PIL image.
-        """
-        if style_img.ndim == 3:
-            style_img = style_img.unsqueeze(0)
-        elif style_img.ndim == 4:
-            pass
-        else:
-            raise ValueError('style_img must be 3D or 4D')
-
-        texts = [style_text + ' ' + gen_text]
-        imgs, _, img_ends = self._generate(texts=texts, imgs=style_img, **kwargs)
-        imgs = (imgs + 1) / 2
-        return F.to_pil_image(imgs[0, ..., style_img.size(-1):img_ends.item()].detach().cpu())
-
-    def generate_batch(
-        self,
-        style_texts: List[str],
-        gen_texts: List[str],
-        style_imgs: torch.Tensor,
-        lengths: List[int],
-        **kwargs: Any
-    ) -> List[Image.Image]:
-        """
-        Generate a batch of images from lists of style texts, generation texts, and style images.
-
-        Args:
-            style_texts (List[str]): List of style-related text prompts.
-            gen_texts (List[str]): List of generation-related text prompts.
-            style_imgs (torch.Tensor): Batch of style images (4D tensor).
-            lengths (List[int]): List of lengths corresponding to each image.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            List[Image.Image]: List of generated images as PIL images.
-        """
-        assert style_imgs.ndim == 4, 'style_imgs must be 4D'
-        assert len(style_texts) == len(style_imgs), 'style_texts and style_imgs must have the same length'
-        assert len(gen_texts) == len(style_imgs), 'gen_texts and style_imgs must have the same length'
-        texts = [style_text + ' ' + gen_text for style_text, gen_text in zip(style_texts, gen_texts)]
-
-        imgs, _, img_ends = self._generate(texts=texts, imgs=style_imgs, lengths=lengths, **kwargs)
-        imgs = (imgs + 1) / 2
-
-        out_imgs = []
-        for i, end in enumerate(img_ends):
-            start = lengths[i]
-            out_imgs.append(F.to_pil_image(imgs[i, ..., start:end].detach().cpu()))
-        return out_imgs
-
-    def _generate(
-        self,
-        texts: Optional[List[str]] = None,
-        imgs: Optional[torch.Tensor] = None,
-        lengths: Optional[List[int]] = None,
-        input_ids: Optional[torch.Tensor] = None,
-        z_sequence: Optional[torch.Tensor] = None,
-        max_new_tokens: int = 256,
-        stopping_criteria: str = 'latent',
-        stopping_after: int = 10,
-        stopping_patience: int = 1
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Internal generation routine that combines textual and visual inputs to iteratively generate
-        latent representations and decode them into images.
-
-        Args:
-            texts (Optional[List[str]]): List of text prompts.
-            imgs (Optional[torch.Tensor]): Input image tensor.
-            lengths (Optional[List[int]]): Desired lengths for each image in latent space.
-            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
-            z_sequence (Optional[torch.Tensor]): Precomputed latent sequence.
-            max_new_tokens (int): Maximum tokens to generate.
-            stopping_criteria (str): Criteria for stopping ('latent' or 'none').
-            stopping_after (int): Number of tokens to check for stopping condition.
-            stopping_patience (int): Patience parameter for stopping condition.
-
-        Returns:
-            Tuple containing:
-                - imgs (torch.Tensor): Generated images.
-                - canvas_sequence (torch.Tensor): Generated latent canvas sequence.
-                - img_ends (torch.Tensor): End indices for each generated image.
-        """
-        assert texts is not None or input_ids is not None, 'Either texts or input_ids must be provided'
-        assert imgs is not None or z_sequence is not None, 'Either imgs or z_sequence must be provided'
-
-        if input_ids is None:
-            input_ids = self.tokenizer(texts, return_tensors='pt', padding=True).input_ids
-            input_ids = input_ids.to(self.device)
-
-        if z_sequence is None:
-            _, z_sequence, _ = self._img_encode(imgs)
-
-        if lengths is None:
-            lengths = [imgs.size(-1)] * imgs.size(0)
-        lengths = torch.tensor(lengths).to(self.device)
-        lengths = (lengths / 8).ceil().int()
-
-        z_sequence_mask = torch.zeros((z_sequence.size(0), lengths.max() + max_new_tokens))
-        z_sequence_mask = z_sequence_mask.bool().to(self.device)
-        for i, l in enumerate(lengths):
-            z_sequence_mask[i, :l] = True
-
-        canvas_sequence = z_sequence[:, :lengths.min()]
-        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
-        pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
-        seq_stops = torch.ones(z_sequence.size(0), dtype=torch.int) * -1
-
-        for token_idx in range(lengths.min(), lengths.max() + max_new_tokens):
-            if len(z_sequence) == 0:
-                decoder_inputs_embeds = sos
-            else:
-                decoder_inputs_embeds = self.vae_to_t5(canvas_sequence)
-                decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
-            output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
-            vae_latent = self.t5_to_vae(output.logits[:, -1:])
-
-            mask_slice = z_sequence_mask[:, token_idx].unsqueeze(-1)
-            if token_idx < z_sequence.size(1):
-                seq_slice = torch.where(mask_slice, z_sequence[:, token_idx], vae_latent[:, 0])
-            else:
-                seq_slice = vae_latent[:, 0]
-            canvas_sequence = torch.cat([canvas_sequence, seq_slice.unsqueeze(1)], dim=1)
-
-            if stopping_criteria == 'latent':
-                similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)
-                windows = (similarity > self.padding_token_threshold).unfold(1, stopping_after, 1)
-                window_sums = windows.to(torch.int).sum(dim=2)
-
-                for i in range(similarity.size(0)):
-                    idx = (window_sums[i] > (stopping_after - stopping_patience)).nonzero(as_tuple=True)[0]
-                    if idx.numel() > 0:
-                        seq_stops[i] = idx[0].item()
-
-                if torch.all(seq_stops >= 0):
-                    break
-            elif stopping_criteria == 'none':
-                pass
-
-        imgs = torch.clamp(self.vae.decode(self.z_rearrange(canvas_sequence)).sample, -1, 1)
-        return imgs, canvas_sequence, seq_stops * 8
-
-    def _img_encode(
-        self,
-        img: torch.Tensor,
-        noise: float = 0
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Encode the input image into a latent representation using the VAE.
-
-        Args:
-            img (torch.Tensor): Input image tensor.
-            noise (float): Standard deviation of noise to add to the latent sequence.
-
-        Returns:
-            Tuple containing:
-                - decoder_inputs_embeds (torch.Tensor): Embeddings to be used as T5 decoder inputs.
-                - z_sequence (torch.Tensor): Rearranged latent sequence from the VAE.
-                - z (torch.Tensor): Sampled latent vector from the VAE.
-        """
-        posterior = self.vae.encode(img.float())
-        z = posterior.latent_dist.sample()
-        z_sequence = self.query_rearrange(z)
-
-        noise_sequence = z_sequence
-        if noise > 0:
-            noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise
-
-        decoder_inputs_embeds = self.vae_to_t5(noise_sequence)
-        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=decoder_inputs_embeds.size(0))
-        decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
-        return decoder_inputs_embeds, z_sequence, z
-
-    def compute_padding_token(self) -> None:
-        """
-        Compute and update the padding token.
-
-        Raises:
-            NotImplementedError: This method must be implemented.
-        """
-        raise NotImplementedError("compute_padding_token not implemented")
-
-    def compute_padding_token_threshold(self) -> None:
-        """
-        Compute and update the padding token threshold.
-
-        Raises:
-            NotImplementedError: This method must be implemented.
-        """
-        raise NotImplementedError("compute_padding_token_threshold not implemented")
modeling_emuru.py CHANGED
@@ -1,23 +1,47 @@
-# modeling_emuru.py
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
-from configuration_emuru import EmuruConfig
-# from .configuration_emuru import EmuruConfig
+from .configuration_emuru import EmuruConfig
 from diffusers import AutoencoderKL
 from einops.layers.torch import Rearrange
-from einops import rearrange, repeat
+from einops import repeat
+from torchvision.transforms import functional as F
+from typing import Optional, Tuple, List, Any
+from PIL import Image
 
 class Emuru(PreTrainedModel):
-    config_class = EmuruConfig  # Link to your configuration
-
-    def __init__(self, config):
+    """
+    Emuru is a conditional generative model that integrates a T5-based decoder with a VAE
+    for image generation conditioned on text and style images.
+
+    Attributes:
+        config_class (Type): Configuration class for the model.
+        tokenizer (AutoTokenizer): Tokenizer loaded from the provided tokenizer configuration.
+        T5 (T5ForConditionalGeneration): T5 model adapted for conditional generation.
+        sos (nn.Embedding): Start-of-sequence embedding.
+        vae_to_t5 (nn.Linear): Linear projection from VAE latent space to T5 hidden space.
+        t5_to_vae (nn.Linear): Linear projection from T5 hidden space back to VAE latent space.
+        padding_token (nn.Parameter): Non-trainable parameter for padding tokens.
+        padding_token_threshold (nn.Parameter): Non-trainable parameter for padding token threshold.
+        vae (AutoencoderKL): Pre-trained Variational Autoencoder.
+        query_rearrange (Rearrange): Layer to rearrange VAE latent representations for queries.
+        z_rearrange (Rearrange): Layer to rearrange T5 outputs back to VAE latent dimensions.
+        mse_criterion (nn.MSELoss): Mean squared error loss function.
+    """
+    config_class = EmuruConfig
+
+    def __init__(self, config: EmuruConfig) -> None:
+        """
+        Initialize the Emuru model.
+
+        Args:
+            config (EmuruConfig): Configuration object containing model hyperparameters and paths.
+        """
         super().__init__(config)
-        # Initialize the tokenizer (if you want it as part of your model)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
 
-        # Load T5 using the provided filename from config
-        t5_config = T5Config.from_pretrained(config.t5_config)
+        t5_config = T5Config.from_pretrained(config.t5_name_or_path)
         t5_config.vocab_size = len(self.tokenizer)
         self.T5 = T5ForConditionalGeneration(t5_config)
         self.T5.lm_head = nn.Identity()
@@ -30,34 +54,51 @@ class Emuru(PreTrainedModel):
         self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
         self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)
 
-        # Load VAE
-        self.vae = AutoencoderKL.from_pretrained(config.vae_config)
+        self.vae = AutoencoderKL.from_pretrained(config.vae_name_or_path)
         self.set_training(self.vae, False)
 
-        # Define the rearrange layers
         self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
         self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
 
-        # Define your loss functions
         self.mse_criterion = nn.MSELoss()
-
-        # Initialize weights following Hugging Face conventions (if needed)
         self.init_weights()
 
-
-    def set_training(self, model, training):
+    def set_training(self, model: nn.Module, training: bool) -> None:
+        """
+        Set the training mode for a given model and freeze/unfreeze parameters accordingly.
+
+        Args:
+            model (nn.Module): The model to set the training mode for.
+            training (bool): If True, set the model to training mode; otherwise, evaluation mode.
+        """
         model.train() if training else model.eval()
         for param in model.parameters():
             param.requires_grad = training
 
-    # --- Implement the rest of your methods ---
-    # For example, _img_encode, forward, generate, etc.
-    # You can largely port your existing code here, making sure that:
-    #   - The forward method returns a dictionary with your losses and outputs.
-    #   - You use the Hugging Face methods for saving/loading weights.
-
-
-    def forward(self, img=None, input_ids=None, attention_mask=None, noise=0, **kwargs):
+    def forward(
+        self,
+        img: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        noise: float = 0,
+        **kwargs: Any
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Forward pass of the model.
+
+        Args:
+            img (Optional[torch.Tensor]): Input image tensor.
+            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
+            attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
+            noise (float): Amount of noise to add in image encoding.
+            **kwargs: Additional arguments.
+
+        Returns:
+            Tuple containing:
+                - mse_loss (torch.Tensor): Mean squared error loss.
+                - pred_latent (torch.Tensor): Predicted latent representations.
+                - z (torch.Tensor): Sampled latent vector from VAE.
+        """
         decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
 
         output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
@@ -67,99 +108,182 @@ class Emuru(PreTrainedModel):
         mse_loss = self.mse_criterion(vae_latent, z_sequence)
         return mse_loss, pred_latent, z
 
-
-    def old_generate(self, text=None, img=None, z_sequence=None, input_ids=None, max_new_tokens=256,
-                     stopping_criteria='latent', stopping_after=10, stopping_errors=1):
-        assert text is not None or input_ids is not None, 'Either text or input_ids must be provided'
-        assert img is not None or z_sequence is not None, 'Either img or z_sequence must be provided'
-
-        if input_ids is None:
-            input_ids = self.tokenizer(text, return_tensors='pt', padding=True).input_ids
-            input_ids = input_ids.to(next(self.T5.parameters()).device)
-
-        if z_sequence is None:
-            _, z_sequence, _ = self._img_encode(img)
-            z_sequence = [z_sequence]
-
-        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
-        for _ in range(max_new_tokens):
-            if len(z_sequence) == 0:
-                decoder_inputs_embeds = sos
-            else:
-                decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
-                decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
-            output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
-            vae_latent = self.t5_to_vae(output.logits[:, -1:])
-            z_sequence.append(vae_latent)
-
-            if stopping_criteria == 'latent':
-                curr_z_sequence = torch.cat(z_sequence, dim=1)
-                pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0)).to(decoder_inputs_embeds.device)
-                similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
-                similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
-                if torch.all(similarity.sum(-1) >= (stopping_after - stopping_errors)):
-                    # z_sequence = [curr_z_sequence[:, :-stopping_after]]
-                    z_sequence = [curr_z_sequence]
-                    break
-            elif stopping_criteria == 'pixel':
-                raise NotImplementedError
-
-        z_sequence = torch.cat(z_sequence, dim=1)
-        img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
-        return img
-
-
-    def generate(self,
-                 style_text=None,
-                 gen_text=None,
-                 style_img=None,
-                 input_ids=None,
-                 z_sequence=None,
-                 max_new_tokens=256,
-                 stopping_criteria='latent',
-                 stopping_after=10,
-                 stopping_patience=1,
-                 trim_image=True):
-        assert (gen_text is not None and style_text is not None) or input_ids is not None, 'Either gen_text and style_text or input_ids must be provided'
-        assert style_img is not None or z_sequence is not None, 'Either style_img or z_sequence must be provided'
+    def generate(
+        self,
+        style_text: str,
+        gen_text: str,
+        style_img: torch.Tensor,
+        **kwargs: Any
+    ) -> Image.Image:
+        """
+        Generate an image by combining style and generation texts with a style image.
+
+        Args:
+            style_text (str): Style-related text prompt.
+            gen_text (str): Generation-related text prompt.
+            style_img (torch.Tensor): Style image tensor. Expected shape is either 3D or 4D.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Image.Image: Generated image as a PIL image.
+        """
+        if style_img.ndim == 3:
+            style_img = style_img.unsqueeze(0)
+        elif style_img.ndim == 4:
+            pass
+        else:
+            raise ValueError('style_img must be 3D or 4D')
+
+        texts = [style_text + ' ' + gen_text]
+        imgs, _, img_ends = self._generate(texts=texts, imgs=style_img, **kwargs)
+        imgs = (imgs + 1) / 2
+        return F.to_pil_image(imgs[0, ..., style_img.size(-1):img_ends.item()].detach().cpu())
+
+    def generate_batch(
+        self,
+        style_texts: List[str],
+        gen_texts: List[str],
+        style_imgs: torch.Tensor,
+        lengths: List[int],
+        **kwargs: Any
+    ) -> List[Image.Image]:
+        """
+        Generate a batch of images from lists of style texts, generation texts, and style images.
+
+        Args:
+            style_texts (List[str]): List of style-related text prompts.
+            gen_texts (List[str]): List of generation-related text prompts.
+            style_imgs (torch.Tensor): Batch of style images (4D tensor).
+            lengths (List[int]): List of lengths corresponding to each image.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[Image.Image]: List of generated images as PIL images.
+        """
+        assert style_imgs.ndim == 4, 'style_imgs must be 4D'
+        assert len(style_texts) == len(style_imgs), 'style_texts and style_imgs must have the same length'
+        assert len(gen_texts) == len(style_imgs), 'gen_texts and style_imgs must have the same length'
+        texts = [style_text + ' ' + gen_text for style_text, gen_text in zip(style_texts, gen_texts)]
+
+        imgs, _, img_ends = self._generate(texts=texts, imgs=style_imgs, lengths=lengths, **kwargs)
+        imgs = (imgs + 1) / 2
+
+        out_imgs = []
+        for i, end in enumerate(img_ends):
+            start = lengths[i]
+            out_imgs.append(F.to_pil_image(imgs[i, ..., start:end].detach().cpu()))
+        return out_imgs
+
+    def _generate(
+        self,
+        texts: Optional[List[str]] = None,
+        imgs: Optional[torch.Tensor] = None,
+        lengths: Optional[List[int]] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        z_sequence: Optional[torch.Tensor] = None,
+        max_new_tokens: int = 256,
+        stopping_criteria: str = 'latent',
+        stopping_after: int = 10,
+        stopping_patience: int = 1
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Internal generation routine that combines textual and visual inputs to iteratively generate
+        latent representations and decode them into images.
+
+        Args:
+            texts (Optional[List[str]]): List of text prompts.
+            imgs (Optional[torch.Tensor]): Input image tensor.
+            lengths (Optional[List[int]]): Desired lengths for each image in latent space.
+            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
+            z_sequence (Optional[torch.Tensor]): Precomputed latent sequence.
+            max_new_tokens (int): Maximum tokens to generate.
+            stopping_criteria (str): Criteria for stopping ('latent' or 'none').
+            stopping_after (int): Number of tokens to check for stopping condition.
+            stopping_patience (int): Patience parameter for stopping condition.
+
+        Returns:
+            Tuple containing:
+                - imgs (torch.Tensor): Generated images.
+                - canvas_sequence (torch.Tensor): Generated latent canvas sequence.
+                - img_ends (torch.Tensor): End indices for each generated image.
+        """
+        assert texts is not None or input_ids is not None, 'Either texts or input_ids must be provided'
+        assert imgs is not None or z_sequence is not None, 'Either imgs or z_sequence must be provided'
 
         if input_ids is None:
-            input_ids = self.tokenizer(gen_text + ' ' + style_text, return_tensors='pt', padding=True).input_ids
+            input_ids = self.tokenizer(texts, return_tensors='pt', padding=True).input_ids
             input_ids = input_ids.to(self.device)
 
         if z_sequence is None:
-            _, z_sequence, _ = self._img_encode(style_img)
-            z_sequence = [z_sequence]
+            _, z_sequence, _ = self._img_encode(imgs)
+
+        if lengths is None:
+            lengths = [imgs.size(-1)] * imgs.size(0)
+        lengths = torch.tensor(lengths).to(self.device)
+        lengths = (lengths / 8).ceil().int()
+
+        z_sequence_mask = torch.zeros((z_sequence.size(0), lengths.max() + max_new_tokens))
+        z_sequence_mask = z_sequence_mask.bool().to(self.device)
+        for i, l in enumerate(lengths):
+            z_sequence_mask[i, :l] = True
 
+        canvas_sequence = z_sequence[:, :lengths.min()]
         sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
         pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
+        seq_stops = torch.ones(z_sequence.size(0), dtype=torch.int) * -1
 
-        for _ in range(max_new_tokens):
+        for token_idx in range(lengths.min(), lengths.max() + max_new_tokens):
             if len(z_sequence) == 0:
                 decoder_inputs_embeds = sos
             else:
-                decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
+                decoder_inputs_embeds = self.vae_to_t5(canvas_sequence)
                 decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
             output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
             vae_latent = self.t5_to_vae(output.logits[:, -1:])
-            z_sequence.append(vae_latent)
+
+            mask_slice = z_sequence_mask[:, token_idx].unsqueeze(-1)
+            if token_idx < z_sequence.size(1):
+                seq_slice = torch.where(mask_slice, z_sequence[:, token_idx], vae_latent[:, 0])
+            else:
+                seq_slice = vae_latent[:, 0]
+            canvas_sequence = torch.cat([canvas_sequence, seq_slice.unsqueeze(1)], dim=1)
 
             if stopping_criteria == 'latent':
-                curr_z_sequence = torch.cat(z_sequence, dim=1)
-                similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
-                similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
-                if torch.all(similarity.sum(-1) >= (stopping_after - stopping_patience)):
-                    z_sequence = [curr_z_sequence[:, :-similarity.sum(-1)]] if trim_image else [curr_z_sequence]
+                similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)
+                windows = (similarity > self.padding_token_threshold).unfold(1, stopping_after, 1)
+                window_sums = windows.to(torch.int).sum(dim=2)
+
+                for i in range(similarity.size(0)):
+                    idx = (window_sums[i] > (stopping_after - stopping_patience)).nonzero(as_tuple=True)[0]
+                    if idx.numel() > 0:
+                        seq_stops[i] = idx[0].item()
+
+                if torch.all(seq_stops >= 0):
                     break
-            elif stopping_criteria == 'pixel':
-                raise NotImplementedError
+            elif stopping_criteria == 'none':
+                pass
 
-        z_sequence = torch.cat(z_sequence, dim=1)
-        img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
-        return img, z_sequence
+        imgs = torch.clamp(self.vae.decode(self.z_rearrange(canvas_sequence)).sample, -1, 1)
+        return imgs, canvas_sequence, seq_stops * 8
 
-
-    def _img_encode(self, img, noise=0):
+    def _img_encode(
+        self,
+        img: torch.Tensor,
+        noise: float = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Encode the input image into a latent representation using the VAE.
+
+        Args:
+            img (torch.Tensor): Input image tensor.
+            noise (float): Standard deviation of noise to add to the latent sequence.
+
+        Returns:
+            Tuple containing:
+                - decoder_inputs_embeds (torch.Tensor): Embeddings to be used as T5 decoder inputs.
+                - z_sequence (torch.Tensor): Rearranged latent sequence from the VAE.
+                - z (torch.Tensor): Sampled latent vector from the VAE.
+        """
         posterior = self.vae.encode(img.float())
         z = posterior.latent_dist.sample()
         z_sequence = self.query_rearrange(z)
@@ -173,10 +297,20 @@ class Emuru(PreTrainedModel):
         decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
         return decoder_inputs_embeds, z_sequence, z
 
 
-    def compute_padding_token(self):
+    def compute_padding_token(self) -> None:
+        """
+        Compute and update the padding token.
+
+        Raises:
+            NotImplementedError: This method must be implemented.
+        """
         raise NotImplementedError("compute_padding_token not implemented")
 
 
-    def compute_padding_token_threshold(self):
-        raise NotImplementedError("compute_padding_token_threshold not implemented")
+    def compute_padding_token_threshold(self) -> None:
+        """
+        Compute and update the padding token threshold.
+
+        Raises:
+            NotImplementedError: This method must be implemented.
+        """
+        raise NotImplementedError("compute_padding_token_threshold not implemented")
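
After this commit the model is meant to be driven through generate / generate_batch with a config built from the renamed fields. A hypothetical end-to-end sketch follows; the import paths, checkpoint handling, image shape, and example texts are assumptions (in practice the repository's pretrained weights would be loaded, for example via from_pretrained, rather than a freshly initialized model):

    import torch
    from configuration_emuru import EmuruConfig
    from modeling_emuru import Emuru   # assumed import path; the file itself uses a relative import

    config = EmuruConfig(
        t5_name_or_path='google-t5/t5-large',
        vae_name_or_path='blowing-up-groundhogs/emuru_vae',
        tokenizer_name_or_path='google/byt5-small',
    )
    model = Emuru(config).eval()       # assumption: real use would load trained weights

    # Assumed style image: a 3-channel strip normalized to [-1, 1];
    # generate() accepts a (C, H, W) tensor or a 4D batch of one image.
    style_img = torch.rand(3, 64, 256) * 2 - 1

    with torch.no_grad():
        pil_img = model.generate(
            style_text='the quick brown fox',    # hypothetical transcription of style_img
            gen_text='jumps over the lazy dog',  # hypothetical text to render in the same style
            style_img=style_img,
            max_new_tokens=64,
            stopping_criteria='none',  # the 'latent' rule relies on the calibrated padding token/threshold
        )
    pil_img.save('emuru_sample.png')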