Vittorio Pippi committed
Commit 69ab272 · 1 Parent(s): 0021de3

Initial commit

.gitignore CHANGED
@@ -1,3 +1,5 @@
  checkpoints
  test.py
- model.py
+ model.py
+ sample.png
+ visual_prompting.py
__pycache__/modeling_emuru.cpython-311.pyc CHANGED
Binary files a/__pycache__/modeling_emuru.cpython-311.pyc and b/__pycache__/modeling_emuru.cpython-311.pyc differ
 
modeling_emuru.py CHANGED
@@ -2,7 +2,8 @@
  import torch
  import torch.nn as nn
  from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
- from .configuration_emuru import EmuruConfig
+ from configuration_emuru import EmuruConfig
+ # from .configuration_emuru import EmuruConfig
  from diffusers import AutoencoderKL
  from einops.layers.torch import Rearrange
  from einops import rearrange, repeat
@@ -43,6 +44,7 @@ class Emuru(PreTrainedModel):
  # Initialize weights following Hugging Face conventions (if needed)
  self.init_weights()

+
  def set_training(self, model, training):
  model.train() if training else model.eval()
  for param in model.parameters():
@@ -53,7 +55,8 @@ class Emuru(PreTrainedModel):
  # You can largely port your existing code here, making sure that:
  # - The forward method returns a dictionary with your losses and outputs.
  # - You use the Hugging Face methods for saving/loading weights.
-
+
+
  def forward(self, img=None, input_ids=None, attention_mask=None, noise=0, **kwargs):
  decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)

@@ -63,11 +66,98 @@
 
  mse_loss = self.mse_criterion(vae_latent, z_sequence)
  return mse_loss, pred_latent, z

- def generate(self, text=None, img=None, max_length=128, noise=0):
- # Your generate implementation (port over from your original code)
- # Make sure to call self._img_encode(img, noise) and use self.T5, etc.
- ...
+
+
+ def old_generate(self, text=None, img=None, z_sequence=None, input_ids=None, max_new_tokens=256,
+ stopping_criteria='latent', stopping_after=10, stopping_errors=1):
+ assert text is not None or input_ids is not None, 'Either text or input_ids must be provided'
+ assert img is not None or z_sequence is not None, 'Either img or z_sequence must be provided'
+
+ if input_ids is None:
+ input_ids = self.tokenizer(text, return_tensors='pt', padding=True).input_ids
+ input_ids = input_ids.to(next(self.T5.parameters()).device)
 
+ if z_sequence is None:
+ _, z_sequence, _ = self._img_encode(img)
+ z_sequence = [z_sequence]
+
+ sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
+ for _ in range(max_new_tokens):
+ if len(z_sequence) == 0:
+ decoder_inputs_embeds = sos
+ else:
+ decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
+ decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
+ output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
+ vae_latent = self.t5_to_vae(output.logits[:, -1:])
+ z_sequence.append(vae_latent)
+
+ if stopping_criteria == 'latent':
+ curr_z_sequence = torch.cat(z_sequence, dim=1)
+ pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0)).to(decoder_inputs_embeds.device)
+ similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
+ similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
+ if torch.all(similarity.sum(-1) >= (stopping_after - stopping_errors)):
+ # z_sequence = [curr_z_sequence[:, :-stopping_after]]
+ z_sequence = [curr_z_sequence]
+ break
+ elif stopping_criteria == 'pixel':
+ raise NotImplementedError
+
+ z_sequence = torch.cat(z_sequence, dim=1)
+ img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
+ return img
+
+
+ def generate(self,
+ style_text=None,
+ gen_text=None,
+ style_img=None,
+ input_ids=None,
+ z_sequence=None,
+ max_new_tokens=256,
+ stopping_criteria='latent',
+ stopping_after=10,
+ stopping_patience=1,
+ trim_image=True):
+ assert (gen_text is not None and style_text is not None) or input_ids is not None, 'Either gen_text and style_text or input_ids must be provided'
+ assert style_img is not None or z_sequence is not None, 'Either style_img or z_sequence must be provided'
+
+ if input_ids is None:
+ input_ids = self.tokenizer(gen_text + ' ' + style_text, return_tensors='pt', padding=True).input_ids
+ input_ids = input_ids.to(self.device)
+
+ if z_sequence is None:
+ _, z_sequence, _ = self._img_encode(style_img)
+ z_sequence = [z_sequence]
+
+ sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
+ pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
+
+ for _ in range(max_new_tokens):
+ if len(z_sequence) == 0:
+ decoder_inputs_embeds = sos
+ else:
+ decoder_inputs_embeds = self.vae_to_t5(torch.cat(z_sequence, dim=1))
+ decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
+ output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
+ vae_latent = self.t5_to_vae(output.logits[:, -1:])
+ z_sequence.append(vae_latent)
+
+ if stopping_criteria == 'latent':
+ curr_z_sequence = torch.cat(z_sequence, dim=1)
+ similarity = torch.nn.functional.cosine_similarity(curr_z_sequence, pad_token, dim=-1)
+ similarity = similarity[:, -stopping_after:] > self.padding_token_threshold
+ if torch.all(similarity.sum(-1) >= (stopping_after - stopping_patience)):
+ z_sequence = [curr_z_sequence[:, :-similarity.sum(-1)]] if trim_image else [curr_z_sequence]
+ break
+ elif stopping_criteria == 'pixel':
+ raise NotImplementedError
+
+ z_sequence = torch.cat(z_sequence, dim=1)
+ img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
+ return img, z_sequence
+
 
  def _img_encode(self, img, noise=0):
  posterior = self.vae.encode(img.float())
@@ -78,15 +168,15 @@ class Emuru(PreTrainedModel):
  if noise > 0:
  noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise
 
- decoder_inputs_embeds = self.query_emb(noise_sequence)
+ decoder_inputs_embeds = self.vae_to_t5(noise_sequence)
  sos = repeat(self.sos.weight, '1 d -> b 1 d', b=decoder_inputs_embeds.size(0))
  decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
  return decoder_inputs_embeds, z_sequence, z
 
+
  def compute_padding_token(self):
- # Your compute_padding_token implementation (port over from your original code)
- ...
+ raise NotImplementedError("compute_padding_token not implemented")
+
 
  def compute_padding_token_threshold(self):
- # Your compute_padding_token_threshold implementation (port over from your original code)
- ...
+ raise NotImplementedError("compute_padding_token_threshold not implemented")
 
output.png ADDED
test.png ADDED
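
For readers trying out the model after this commit, below is a minimal usage sketch (not part of the commit itself). Only the Emuru class and the generate() signature come from modeling_emuru.py above; the checkpoint path, the image preprocessing, and the assumption that style images are normalized to [-1, 1] (mirroring the torch.clamp(..., -1, 1) applied to the decoded output) are hypothetical.

import torch
from PIL import Image
from torchvision import transforms

from modeling_emuru import Emuru

# Hypothetical checkpoint directory, assumed to have been produced with save_pretrained().
model = Emuru.from_pretrained("path/to/emuru_checkpoint")
model.eval()

# Assumed preprocessing: RGB crop of a handwriting line, scaled to [-1, 1].
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
style_img = to_tensor(Image.open("style_sample.png").convert("RGB")).unsqueeze(0)

with torch.inference_mode():
    img, z_sequence = model.generate(
        style_text="the quick brown fox",    # transcription of the style image
        gen_text="jumps over the lazy dog",  # text to render in the same style
        style_img=style_img,
        max_new_tokens=256,
        stopping_criteria="latent",
    )

# The decoded image is in [-1, 1]; rescale to [0, 1] before saving.
transforms.ToPILImage()((img[0] + 1) / 2).save("generated.png")

Note that the default 'latent' stopping criterion reads self.padding_token and self.padding_token_threshold, while compute_padding_token() and compute_padding_token_threshold() still raise NotImplementedError in this commit, so those attributes would have to be provided by the checkpoint or set manually for early stopping to work.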