Vittorio Pippi commited on
Commit
935404c
·
1 Parent(s): 69ab272

Inital commit

Browse files
.gitignore CHANGED
@@ -3,3 +3,6 @@ test.py
3
  model.py
4
  sample.png
5
  visual_prompting.py
 
 
 
 
3
  model.py
4
  sample.png
5
  visual_prompting.py
6
+ test.png
7
+ sample2.png
8
+ output.png
__pycache__/modeling_emuru.cpython-311.pyc CHANGED
Binary files a/__pycache__/modeling_emuru.cpython-311.pyc and b/__pycache__/modeling_emuru.cpython-311.pyc differ
 
modeling_emuru.py CHANGED
@@ -1,22 +1,46 @@
1
- # modeling_emuru.py
2
  import torch
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
5
  from configuration_emuru import EmuruConfig
6
- # from .configuration_emuru import EmuruConfig
7
  from diffusers import AutoencoderKL
8
  from einops.layers.torch import Rearrange
9
- from einops import rearrange, repeat
 
 
 
10
 
11
  class Emuru(PreTrainedModel):
12
- config_class = EmuruConfig # Link to your configuration
13
-
14
- def __init__(self, config):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  super().__init__(config)
16
- # Initialize the tokenizer (if you want it as part of your model)
17
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)
18
 
19
- # Load T5 using the provided filename from config
20
  t5_config = T5Config.from_pretrained(config.t5_config)
21
  t5_config.vocab_size = len(self.tokenizer)
22
  self.T5 = T5ForConditionalGeneration(t5_config)
@@ -30,34 +54,51 @@ class Emuru(PreTrainedModel):
30
  self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
31
  self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)
32
 
33
- # Load VAE
34
  self.vae = AutoencoderKL.from_pretrained(config.vae_config)
35
  self.set_training(self.vae, False)
36
 
37
- # Define the rearrange layers
38
  self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
39
  self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
40
 
41
- # Define your loss functions
42
  self.mse_criterion = nn.MSELoss()
43
-
44
- # Initialize weights following Hugging Face conventions (if needed)
45
  self.init_weights()
46
 
 
 
 
47
 
48
- def set_training(self, model, training):
 
 
 
49
  model.train() if training else model.eval()
50
  for param in model.parameters():
51
  param.requires_grad = training
52
 
53
- # --- Implement the rest of your methods ---
54
- # For example, _img_encode, forward, generate, etc.
55
- # You can largely port your existing code here, making sure that:
56
- # - The forward method returns a dictionary with your losses and outputs.
57
- # - You use the Hugging Face methods for saving/loading weights.
58
-
59
-
60
- def forward(self, img=None, input_ids=None, attention_mask=None, noise=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
62
 
63
  output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
@@ -67,99 +108,182 @@ class Emuru(PreTrainedModel):
67
  mse_loss = self.mse_criterion(vae_latent, z_sequence)
68
  return mse_loss, pred_latent, z
69
 
70
-
71
- def old_generate(self, text=None, img=None, z_sequence=None, input_ids=None, max_new_tokens=256,
72
- stopping_criteria='latent', stopping_after=10, stopping_errors=1):
73
- assert text is not None or input_ids is not None, 'Either text or input_ids must be provided'
74
- assert img is not None or z_sequence is not None, 'Either img or z_sequence must be provided'
75
-
76
- if input_ids is None:
77
- input_ids = self.tokenizer(text, return_tensors='pt', padding=True).input_ids
78
- input_ids = input_ids.to(next(self.T5.parameters()).device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- if z_sequence is None:
81
- _, z_sequence, _ = self._img_encode(img)
82
- z_sequence = [z_sequence]
83
-
84
- sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
85
- for _ in range(max_new_tokens):
86
- if len(z_sequence) == 0:
87
- decoder_inputs_embeds = sos
88
- else:
89
- decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
90
- decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
91
- output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
92
- vae_latent = self.t5_to_vae(output.logits[:, -1:])
93
- z_sequence.append(vae_latent)
94
-
95
- if stopping_criteria == 'latent':
96
- curr_z_sequence = torch.cat(z_sequence, dim=1)
97
- pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0)).to(decoder_inputs_embeds.device)
98
- similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
99
- similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
100
- if torch.all(similarity.sum(-1) >= (stopping_after - stopping_errors)):
101
- # z_sequence = [curr_z_sequence[:, :-stopping_after]]
102
- z_sequence = [curr_z_sequence]
103
- break
104
- elif stopping_criteria == 'pixel':
105
- raise NotImplementedError
106
-
107
- z_sequence = torch.cat(z_sequence, dim=1)
108
- img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
109
- return img
110
-
111
-
112
- def generate(self,
113
- style_text=None,
114
- gen_text=None,
115
- style_img=None,
116
- input_ids=None,
117
- z_sequence=None,
118
- max_new_tokens=256,
119
- stopping_criteria='latent',
120
- stopping_after=10,
121
- stopping_patience=1,
122
- trim_image=True):
123
- assert (gen_text is not None and style_text is not None) or input_ids is not None, 'Either gen_text and style_text or input_ids must be provided'
124
- assert style_img is not None or z_sequence is not None, 'Either style_img or z_sequence must be provided'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  if input_ids is None:
127
- input_ids = self.tokenizer(gen_text + ' ' + style_text, return_tensors='pt', padding=True).input_ids
128
  input_ids = input_ids.to(self.device)
129
 
130
  if z_sequence is None:
131
- _, z_sequence, _ = self._img_encode(style_img)
132
- z_sequence = [z_sequence]
 
 
 
 
 
 
 
 
 
133
 
 
134
  sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
135
  pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
 
136
 
137
- for _ in range(max_new_tokens):
138
  if len(z_sequence) == 0:
139
  decoder_inputs_embeds = sos
140
  else:
141
- decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
142
  decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
143
  output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
144
  vae_latent = self.t5_to_vae(output.logits[:, -1:])
145
- z_sequence.append(vae_latent)
 
 
 
 
 
 
146
 
147
  if stopping_criteria == 'latent':
148
- curr_z_sequence = torch.cat(z_sequence, dim=1)
149
- similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
150
- similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
151
- if torch.all(similarity.sum(-1) >= (stopping_after - stopping_patience)):
152
- z_sequence = [curr_z_sequence[:, :-similarity.sum(-1)]] if trim_image else [curr_z_sequence]
 
 
 
 
 
153
  break
154
- elif stopping_criteria == 'pixel':
155
- raise NotImplementedError
156
 
157
- z_sequence = torch.cat(z_sequence, dim=1)
158
- img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
159
- return img, z_sequence
160
 
161
-
162
- def _img_encode(self, img, noise=0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  posterior = self.vae.encode(img.float())
164
  z = posterior.latent_dist.sample()
165
  z_sequence = self.query_rearrange(z)
@@ -173,10 +297,20 @@ class Emuru(PreTrainedModel):
173
  decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
174
  return decoder_inputs_embeds, z_sequence, z
175
 
 
 
 
176
 
177
- def compute_padding_token(self):
 
 
178
  raise NotImplementedError("compute_padding_token not implemented")
179
 
 
 
 
180
 
181
- def compute_padding_token_threshold(self):
 
 
182
  raise NotImplementedError("compute_padding_token_threshold not implemented")
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
4
  from configuration_emuru import EmuruConfig
 
5
  from diffusers import AutoencoderKL
6
  from einops.layers.torch import Rearrange
7
+ from einops import repeat
8
+ from torchvision.transforms import functional as F
9
+ from typing import Optional, Tuple, List, Any
10
+ from PIL import Image
11
 
12
  class Emuru(PreTrainedModel):
13
+ """
14
+ Emuru is a conditional generative model that integrates a T5-based decoder with a VAE
15
+ for image generation conditioned on text and style images.
16
+
17
+ Attributes:
18
+ config_class (Type): Configuration class for the model.
19
+ tokenizer (AutoTokenizer): Tokenizer loaded from the provided tokenizer configuration.
20
+ T5 (T5ForConditionalGeneration): T5 model adapted for conditional generation.
21
+ sos (nn.Embedding): Start-of-sequence embedding.
22
+ vae_to_t5 (nn.Linear): Linear projection from VAE latent space to T5 hidden space.
23
+ t5_to_vae (nn.Linear): Linear projection from T5 hidden space back to VAE latent space.
24
+ padding_token (nn.Parameter): Non-trainable parameter for padding tokens.
25
+ padding_token_threshold (nn.Parameter): Non-trainable parameter for padding token threshold.
26
+ vae (AutoencoderKL): Pre-trained Variational Autoencoder.
27
+ query_rearrange (Rearrange): Layer to rearrange VAE latent representations for queries.
28
+ z_rearrange (Rearrange): Layer to rearrange T5 outputs back to VAE latent dimensions.
29
+ mse_criterion (nn.MSELoss): Mean squared error loss function.
30
+ """
31
+ config_class = EmuruConfig
32
+
33
+ def __init__(self, config: EmuruConfig) -> None:
34
+ """
35
+ Initialize the Emuru model.
36
+
37
+ Args:
38
+ config (EmuruConfig): Configuration object containing model hyperparameters and paths.
39
+ """
40
  super().__init__(config)
41
+
42
  self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)
43
 
 
44
  t5_config = T5Config.from_pretrained(config.t5_config)
45
  t5_config.vocab_size = len(self.tokenizer)
46
  self.T5 = T5ForConditionalGeneration(t5_config)
 
54
  self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
55
  self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)
56
 
 
57
  self.vae = AutoencoderKL.from_pretrained(config.vae_config)
58
  self.set_training(self.vae, False)
59
 
 
60
  self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
61
  self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
62
 
 
63
  self.mse_criterion = nn.MSELoss()
 
 
64
  self.init_weights()
65
 
66
+ def set_training(self, model: nn.Module, training: bool) -> None:
67
+ """
68
+ Set the training mode for a given model and freeze/unfreeze parameters accordingly.
69
 
70
+ Args:
71
+ model (nn.Module): The model to set the training mode for.
72
+ training (bool): If True, set the model to training mode; otherwise, evaluation mode.
73
+ """
74
  model.train() if training else model.eval()
75
  for param in model.parameters():
76
  param.requires_grad = training
77
 
78
+ def forward(
79
+ self,
80
+ img: Optional[torch.Tensor] = None,
81
+ input_ids: Optional[torch.Tensor] = None,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ noise: float = 0,
84
+ **kwargs: Any
85
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
86
+ """
87
+ Forward pass of the model.
88
+
89
+ Args:
90
+ img (Optional[torch.Tensor]): Input image tensor.
91
+ input_ids (Optional[torch.Tensor]): Tokenized input IDs.
92
+ attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
93
+ noise (float): Amount of noise to add in image encoding.
94
+ **kwargs: Additional arguments.
95
+
96
+ Returns:
97
+ Tuple containing:
98
+ - mse_loss (torch.Tensor): Mean squared error loss.
99
+ - pred_latent (torch.Tensor): Predicted latent representations.
100
+ - z (torch.Tensor): Sampled latent vector from VAE.
101
+ """
102
  decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
103
 
104
  output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
 
108
  mse_loss = self.mse_criterion(vae_latent, z_sequence)
109
  return mse_loss, pred_latent, z
110
 
111
+ def generate(
112
+ self,
113
+ style_text: str,
114
+ gen_text: str,
115
+ style_img: torch.Tensor,
116
+ **kwargs: Any
117
+ ) -> Image.Image:
118
+ """
119
+ Generate an image by combining style and generation texts with a style image.
120
+
121
+ Args:
122
+ style_text (str): Style-related text prompt.
123
+ gen_text (str): Generation-related text prompt.
124
+ style_img (torch.Tensor): Style image tensor. Expected shape is either 3D or 4D.
125
+ **kwargs: Additional keyword arguments.
126
+
127
+ Returns:
128
+ Image.Image: Generated image as a PIL image.
129
+ """
130
+ if style_img.ndim == 3:
131
+ style_img = style_img.unsqueeze(0)
132
+ elif style_img.ndim == 4:
133
+ pass
134
+ else:
135
+ raise ValueError('style_img must be 3D or 4D')
136
 
137
+ texts = [style_text + ' ' + gen_text]
138
+ imgs, _, img_ends = self._generate(texts=texts, imgs=style_img, **kwargs)
139
+ imgs = (imgs + 1) / 2
140
+ return F.to_pil_image(imgs[0, ..., style_img.size(-1):img_ends.item()].detach().cpu())
141
+
142
+ def generate_batch(
143
+ self,
144
+ style_texts: List[str],
145
+ gen_texts: List[str],
146
+ style_imgs: torch.Tensor,
147
+ lengths: List[int],
148
+ **kwargs: Any
149
+ ) -> List[Image.Image]:
150
+ """
151
+ Generate a batch of images from lists of style texts, generation texts, and style images.
152
+
153
+ Args:
154
+ style_texts (List[str]): List of style-related text prompts.
155
+ gen_texts (List[str]): List of generation-related text prompts.
156
+ style_imgs (torch.Tensor): Batch of style images (4D tensor).
157
+ lengths (List[int]): List of lengths corresponding to each image.
158
+ **kwargs: Additional keyword arguments.
159
+
160
+ Returns:
161
+ List[Image.Image]: List of generated images as PIL images.
162
+ """
163
+ assert style_imgs.ndim == 4, 'style_imgs must be 4D'
164
+ assert len(style_texts) == len(style_imgs), 'style_texts and style_imgs must have the same length'
165
+ assert len(gen_texts) == len(style_imgs), 'gen_texts and style_imgs must have the same length'
166
+ texts = [style_text + ' ' + gen_text for style_text, gen_text in zip(style_texts, gen_texts)]
167
+
168
+ imgs, _, img_ends = self._generate(texts=texts, imgs=style_imgs, lengths=lengths, **kwargs)
169
+ imgs = (imgs + 1) / 2
170
+
171
+ out_imgs = []
172
+ for i, end in enumerate(img_ends):
173
+ start = lengths[i]
174
+ out_imgs.append(F.to_pil_image(imgs[i, ..., start:end].detach().cpu()))
175
+ return out_imgs
176
+
177
+ def _generate(
178
+ self,
179
+ texts: Optional[List[str]] = None,
180
+ imgs: Optional[torch.Tensor] = None,
181
+ lengths: Optional[List[int]] = None,
182
+ input_ids: Optional[torch.Tensor] = None,
183
+ z_sequence: Optional[torch.Tensor] = None,
184
+ max_new_tokens: int = 256,
185
+ stopping_criteria: str = 'latent',
186
+ stopping_after: int = 10,
187
+ stopping_patience: int = 1
188
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
189
+ """
190
+ Internal generation routine that combines textual and visual inputs to iteratively generate
191
+ latent representations and decode them into images.
192
+
193
+ Args:
194
+ texts (Optional[List[str]]): List of text prompts.
195
+ imgs (Optional[torch.Tensor]): Input image tensor.
196
+ lengths (Optional[List[int]]): Desired lengths for each image in latent space.
197
+ input_ids (Optional[torch.Tensor]): Tokenized input IDs.
198
+ z_sequence (Optional[torch.Tensor]): Precomputed latent sequence.
199
+ max_new_tokens (int): Maximum tokens to generate.
200
+ stopping_criteria (str): Criteria for stopping ('latent' or 'none').
201
+ stopping_after (int): Number of tokens to check for stopping condition.
202
+ stopping_patience (int): Patience parameter for stopping condition.
203
+
204
+ Returns:
205
+ Tuple containing:
206
+ - imgs (torch.Tensor): Generated images.
207
+ - canvas_sequence (torch.Tensor): Generated latent canvas sequence.
208
+ - img_ends (torch.Tensor): End indices for each generated image.
209
+ """
210
+ assert texts is not None or input_ids is not None, 'Either texts or input_ids must be provided'
211
+ assert imgs is not None or z_sequence is not None, 'Either imgs or z_sequence must be provided'
212
 
213
  if input_ids is None:
214
+ input_ids = self.tokenizer(texts, return_tensors='pt', padding=True).input_ids
215
  input_ids = input_ids.to(self.device)
216
 
217
  if z_sequence is None:
218
+ _, z_sequence, _ = self._img_encode(imgs)
219
+
220
+ if lengths is None:
221
+ lengths = [imgs.size(-1)] * imgs.size(0)
222
+ lengths = torch.tensor(lengths).to(self.device)
223
+ lengths = (lengths / 8).ceil().int()
224
+
225
+ z_sequence_mask = torch.zeros((z_sequence.size(0), lengths.max() + max_new_tokens))
226
+ z_sequence_mask = z_sequence_mask.bool().to(self.device)
227
+ for i, l in enumerate(lengths):
228
+ z_sequence_mask[i, :l] = True
229
 
230
+ canvas_sequence = z_sequence[:, :lengths.min()]
231
  sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
232
  pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
233
+ seq_stops = torch.ones(z_sequence.size(0), dtype=torch.int) * -1
234
 
235
+ for token_idx in range(lengths.min(), lengths.max() + max_new_tokens):
236
  if len(z_sequence) == 0:
237
  decoder_inputs_embeds = sos
238
  else:
239
+ decoder_inputs_embeds = self.vae_to_t5(canvas_sequence)
240
  decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
241
  output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
242
  vae_latent = self.t5_to_vae(output.logits[:, -1:])
243
+
244
+ mask_slice = z_sequence_mask[:, token_idx].unsqueeze(-1)
245
+ if token_idx < z_sequence.size(1):
246
+ seq_slice = torch.where(mask_slice, z_sequence[:, token_idx], vae_latent[:, 0])
247
+ else:
248
+ seq_slice = vae_latent[:, 0]
249
+ canvas_sequence = torch.cat([canvas_sequence, seq_slice.unsqueeze(1)], dim=1)
250
 
251
  if stopping_criteria == 'latent':
252
+ similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)
253
+ windows = (similarity > self.padding_token_threshold).unfold(1, stopping_after, 1)
254
+ window_sums = windows.to(torch.int).sum(dim=2)
255
+
256
+ for i in range(similarity.size(0)):
257
+ idx = (window_sums[i] > (stopping_after - stopping_patience)).nonzero(as_tuple=True)[0]
258
+ if idx.numel() > 0:
259
+ seq_stops[i] = idx[0].item()
260
+
261
+ if torch.all(seq_stops >= 0):
262
  break
263
+ elif stopping_criteria == 'none':
264
+ pass
265
 
266
+ imgs = torch.clamp(self.vae.decode(self.z_rearrange(canvas_sequence)).sample, -1, 1)
267
+ return imgs, canvas_sequence, seq_stops * 8
 
268
 
269
+ def _img_encode(
270
+ self,
271
+ img: torch.Tensor,
272
+ noise: float = 0
273
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
274
+ """
275
+ Encode the input image into a latent representation using the VAE.
276
+
277
+ Args:
278
+ img (torch.Tensor): Input image tensor.
279
+ noise (float): Standard deviation of noise to add to the latent sequence.
280
+
281
+ Returns:
282
+ Tuple containing:
283
+ - decoder_inputs_embeds (torch.Tensor): Embeddings to be used as T5 decoder inputs.
284
+ - z_sequence (torch.Tensor): Rearranged latent sequence from the VAE.
285
+ - z (torch.Tensor): Sampled latent vector from the VAE.
286
+ """
287
  posterior = self.vae.encode(img.float())
288
  z = posterior.latent_dist.sample()
289
  z_sequence = self.query_rearrange(z)
 
297
  decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
298
  return decoder_inputs_embeds, z_sequence, z
299
 
300
+ def compute_padding_token(self) -> None:
301
+ """
302
+ Compute and update the padding token.
303
 
304
+ Raises:
305
+ NotImplementedError: This method must be implemented.
306
+ """
307
  raise NotImplementedError("compute_padding_token not implemented")
308
 
309
+ def compute_padding_token_threshold(self) -> None:
310
+ """
311
+ Compute and update the padding token threshold.
312
 
313
+ Raises:
314
+ NotImplementedError: This method must be implemented.
315
+ """
316
  raise NotImplementedError("compute_padding_token_threshold not implemented")
output.png DELETED
Binary file (19.3 kB)
 
test.png DELETED
Binary file (24.4 kB)