Vittorio Pippi committed · Commit 0021de3 · Parent(s): 482b875

Initial commit

Files changed:
- __pycache__/modeling_emuru.cpython-311.pyc +0 -0
- config.json +3 -2
- model.safetensors +1 -1
- modeling_emuru.py +35 -6
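To inspect the repository at exactly this commit, one option is a pinned snapshot download. A minimal sketch using huggingface_hub; the short hash 0021de3 is assumed to resolve uniquely on the Hub:

# Sketch: pull the repo pinned to this commit with huggingface_hub.
from huggingface_hub import snapshot_download

local_dir = snapshot_download("blowing-up-groundhogs/emuru", revision="0021de3")
print(local_dir)  # local cache path containing config.json, model.safetensors, ...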
__pycache__/modeling_emuru.cpython-311.pyc
CHANGED
Binary files a/__pycache__/modeling_emuru.cpython-311.pyc and b/__pycache__/modeling_emuru.cpython-311.pyc differ
config.json
CHANGED
@@ -1,10 +1,11 @@
 {
+  "_name_or_path": "blowing-up-groundhogs/emuru",
   "architectures": [
     "Emuru"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_emuru.EmuruConfig",
-    "AutoModel": "modeling_emuru.Emuru"
+    "AutoConfig": "blowing-up-groundhogs/emuru--configuration_emuru.EmuruConfig",
+    "AutoModel": "blowing-up-groundhogs/emuru--modeling_emuru.Emuru"
   },
   "model_type": "emuru",
   "slices_per_query": 1,
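Because the updated auto_map entries are repo-qualified (blowing-up-groundhogs/emuru--configuration_emuru.EmuruConfig), the Auto classes can resolve the custom code straight from the Hub. A minimal loading sketch, assuming the files above are published under that repo id; trust_remote_code=True is required because the config/model classes ship with the repository rather than with transformers:

# Sketch: load the custom Emuru config and model via the auto_map above.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("blowing-up-groundhogs/emuru", trust_remote_code=True)
model = AutoModel.from_pretrained("blowing-up-groundhogs/emuru", trust_remote_code=True)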
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:59be7ae3ad22ca92fb41555f87b2c6051a645d805fd471e415b7e5f68369a9a2
 size 2876698952
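The safetensors file is tracked through Git LFS, so the committed file is only this three-line pointer; the commit fills in the real oid for the weights blob. A small verification sketch, assuming the full weights have already been downloaded to model.safetensors locally:

# Sketch: hash the downloaded weights and compare against the LFS pointer's oid.
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Path is an assumption; point it at your local checkout.
assert sha256_of("model.safetensors") == "59be7ae3ad22ca92fb41555f87b2c6051a645d805fd471e415b7e5f68369a9a2"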
modeling_emuru.py
CHANGED
@@ -23,7 +23,7 @@ class Emuru(PreTrainedModel):
         self.sos = nn.Embedding(1, t5_config.d_model)
 
         vae_latent_size = 8 * config.vae_channels * config.slices_per_query
-        self.
+        self.vae_to_t5 = nn.Linear(vae_latent_size, t5_config.d_model)
         self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)
 
         self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
@@ -36,7 +36,6 @@ class Emuru(PreTrainedModel):
         # Define the rearrange layers
         self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
         self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
-        self.z_rearrange_eval = Rearrange('w b (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
 
         # Define your loss functions
         self.mse_criterion = nn.MSELoss()
@@ -55,9 +54,39 @@ class Emuru(PreTrainedModel):
     # - The forward method returns a dictionary with your losses and outputs.
     # - You use the Hugging Face methods for saving/loading weights.
 
-    def forward(self,
-
+    def forward(self, img=None, input_ids=None, attention_mask=None, noise=0, **kwargs):
+        decoder_inputs_embeds, z_sequence, z = self._img_encode(img, noise)
+
+        output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)
+        vae_latent = self.t5_to_vae(output.logits[:, :-1])
+        pred_latent = self.z_rearrange(vae_latent)
+
+        mse_loss = self.mse_criterion(vae_latent, z_sequence)
+        return mse_loss, pred_latent, z
+
+    def generate(self, text=None, img=None, max_length=128, noise=0):
+        # Your generate implementation (port over from your original code)
         # Make sure to call self._img_encode(img, noise) and use self.T5, etc.
         ...
-
-
+
+    def _img_encode(self, img, noise=0):
+        posterior = self.vae.encode(img.float())
+        z = posterior.latent_dist.sample()
+        z_sequence = self.query_rearrange(z)
+
+        noise_sequence = z_sequence
+        if noise > 0:
+            noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise
+
+        decoder_inputs_embeds = self.query_emb(noise_sequence)
+        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=decoder_inputs_embeds.size(0))
+        decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
+        return decoder_inputs_embeds, z_sequence, z
+
+    def compute_padding_token(self):
+        # Your compute_padding_token implementation (port over from your original code)
+        ...
+
+    def compute_padding_token_threshold(self):
+        # Your compute_padding_token_threshold implementation (port over from your original code)
+        ...
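The new vae_to_t5/t5_to_vae projections and the rearrange layers all hinge on the token width 8 * config.vae_channels * config.slices_per_query. A self-contained sketch with assumed toy sizes (not the model's real configuration) showing that z_rearrange inverts query_rearrange and that each sequence token packs q latent columns:

# Sketch: verify the latent <-> sequence bookkeeping with toy shapes.
import torch
from einops.layers.torch import Rearrange

b, c, h, w, q = 2, 1, 8, 16, 1  # batch, vae_channels, latent height, latent width, slices_per_query
query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=q)
z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=c, q=q)

z = torch.randn(b, c, h, w)      # VAE latent image
z_sequence = query_rearrange(z)  # one token per group of q latent columns
assert z_sequence.shape == (b, w // q, 8 * c * q)  # token width == vae_latent_size when latent height is 8

# z_rearrange is the exact inverse, mapping predicted tokens back to a latent image.
assert torch.equal(z_rearrange(z_sequence), z)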