Vittorio Pippi commited on
Commit
19f124c
·
1 Parent(s): 3e42bd1

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ checkpoints
2
+ test.py
3
+ model.py
__pycache__/configuration_emuru.cpython-311.pyc ADDED
Binary file (1.14 kB). View file
 
__pycache__/modeling_emuru.cpython-311.pyc ADDED
Binary file (3.87 kB). View file
 
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Emuru"
4
+ ],
5
+ "model_type": "emuru",
6
+ "slices_per_query": 1,
7
+ "t5_config": "google-t5/t5-large",
8
+ "tokenizer_config": "google/byt5-small",
9
+ "torch_dtype": "float32",
10
+ "transformers_version": "4.38.2",
11
+ "vae_channels": 1,
12
+ "vae_config": "blowing-up-groundhogs/emuru_vae"
13
+ }
configuration_emuru.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class EmuruConfig(PretrainedConfig):
4
+ model_type = "emuru" # Unique identifier for your model
5
+
6
+ def __init__(self,
7
+ t5_config='google-t5/t5-large',
8
+ vae_config='blowing-up-groundhogs/emuru_vae',
9
+ tokenizer_config='google/byt5-small',
10
+ slices_per_query=1,
11
+ vae_channels=1,
12
+ **kwargs):
13
+ super().__init__(**kwargs)
14
+ self.t5_config = t5_config
15
+ self.vae_config = vae_config
16
+ self.tokenizer_config = tokenizer_config
17
+ self.slices_per_query = slices_per_query
18
+ self.vae_channels = vae_channels
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b1c1d034a3c8b158d4f656bdccf4498816c8499300da30b764fa9f290f33f7f
3
+ size 2876698952
modeling_emuru.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_emuru.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
5
+ from configuration_emuru import EmuruConfig
6
+ from diffusers import AutoencoderKL
7
+ from einops.layers.torch import Rearrange
8
+ from einops import rearrange, repeat
9
+
10
+ class Emuru(PreTrainedModel):
11
+ config_class = EmuruConfig # Link to your configuration
12
+
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ # Initialize the tokenizer (if you want it as part of your model)
16
+ self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_config)
17
+
18
+ # Load T5 using the provided filename from config
19
+ t5_config = T5Config.from_pretrained(config.t5_config)
20
+ t5_config.vocab_size = len(self.tokenizer)
21
+ self.T5 = T5ForConditionalGeneration(t5_config)
22
+ self.T5.lm_head = nn.Identity()
23
+ self.sos = nn.Embedding(1, t5_config.d_model)
24
+
25
+ vae_latent_size = 8 * config.vae_channels * config.slices_per_query
26
+ self.query_emb = nn.Linear(vae_latent_size, t5_config.d_model)
27
+ self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)
28
+
29
+ self.padding_token = nn.Parameter(torch.empty((1, vae_latent_size)), requires_grad=False)
30
+ self.padding_token_threshold = nn.Parameter(torch.empty(1), requires_grad=False)
31
+
32
+ # Load VAE
33
+ self.vae = AutoencoderKL.from_pretrained(config.vae_config)
34
+ self.set_training(self.vae, False)
35
+
36
+ # Define the rearrange layers
37
+ self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
38
+ self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
39
+ self.z_rearrange_eval = Rearrange('w b (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
40
+
41
+ # Define your loss functions
42
+ self.mse_criterion = nn.MSELoss()
43
+
44
+ # Initialize weights following Hugging Face conventions (if needed)
45
+ self.init_weights()
46
+
47
+ def set_training(self, model, training):
48
+ model.train() if training else model.eval()
49
+ for param in model.parameters():
50
+ param.requires_grad = training
51
+
52
+ # --- Implement the rest of your methods ---
53
+ # For example, _img_encode, forward, generate, etc.
54
+ # You can largely port your existing code here, making sure that:
55
+ # - The forward method returns a dictionary with your losses and outputs.
56
+ # - You use the Hugging Face methods for saving/loading weights.
57
+
58
+ def forward(self, text=None, img=None, input_ids=None, attention_mask=None, length=None, noise=0):
59
+ # Your forward implementation (port over from your original code)
60
+ # Make sure to call self._img_encode(img, noise) and use self.T5, etc.
61
+ ...
62
+
63
+ # Add other methods (forward_recurrent, generate, etc.) as needed.