wenxiang guo committed
Commit 1a66ebf · verified · 1 Parent(s): 45f7a0e

Update ldm/modules/encoders/modules.py

Files changed (1):
  1. ldm/modules/encoders/modules.py +187 -187
ldm/modules/encoders/modules.py CHANGED
@@ -1,187 +1,187 @@
 import os.path

 import torch
 import torch.nn as nn
 from functools import partial
 from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer
 from importlib_resources import files
 from ldm.modules.encoders.CLAP.utils import read_config_as_args
 from ldm.modules.encoders.CLAP.clap import TextEncoder
 from ldm.util import count_params
 import numpy as np




 class Video_Feat_Encoder_NoPosembed(nn.Module):
     """ Transform the video feat encoder"""

     def __init__(self, origin_dim, embed_dim, seq_len=40):
         super().__init__()
         self.embedder = nn.Sequential(nn.Linear(origin_dim, embed_dim))

     def forward(self, x):
         # Revise the shape here:
         x = self.embedder(x)  # B x 117 x C

         return x



 class Video_Feat_Encoder_NoPosembed_inpaint(Video_Feat_Encoder_NoPosembed):
     """ Transform the video feat encoder"""

     def forward(self, x):
         # Revise the shape here:
         video, spec = x['mix_video_feat'], x['mix_spec']
         video = self.embedder(video)  # B x 117 x C

         return (video, spec)

 class AbstractEncoder(nn.Module):
     def __init__(self):
         super().__init__()

     def encode(self, *args, **kwargs):
         raise NotImplementedError


 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
     return self

 class FrozenFLANEmbedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""

-    def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77,
+    def __init__(self, version="google/flan-t5-large", device="cpu", max_length=77,
                  freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()

         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
         if freeze:
             self.freeze()

     def freeze(self):
         self.transformer = self.transformer.eval()
         # self.train = disabled_train
         for param in self.parameters():
             param.requires_grad = False

     def forward(self, text):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)  # tango's flan-T5 uses variable-length batches; here we pad to a fixed-length batch
         outputs = self.transformer(input_ids=tokens)

         z = outputs.last_hidden_state
         return z

     def encode(self, text):
         return self(text)


 class FrozenCLAPEmbedder(AbstractEncoder):
     """Uses the CLAP transformer encoder for text from microsoft"""

     def __init__(self, weights_path, freeze=True, device="cuda", max_length=77):  # clip-vit-base-patch32
         super().__init__()

         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
         match_params = dict()
         for key in list(model_state_dict.keys()):
             if 'caption_encoder' in key:
                 match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]

         config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
         args = read_config_as_args(config_as_str, is_config_str=True)

         self.tokenizer = AutoTokenizer.from_pretrained(args.text_model)  # args.text_model
         self.caption_encoder = TextEncoder(
             args.d_proj, args.text_model, args.transformer_embed_dim
         )

         self.max_length = max_length
         self.device = device
         if freeze: self.freeze()

         print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")

     def freeze(self):  # only freeze
         self.caption_encoder.base = self.caption_encoder.base.eval()
         for param in self.caption_encoder.base.parameters():
             param.requires_grad = False

     def encode(self, text):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)

         outputs = self.caption_encoder.base(input_ids=tokens)
         z = self.caption_encoder.projection(outputs.last_hidden_state)
         return z


 class FrozenCLAPFLANEmbedder(AbstractEncoder):
     """Uses the CLAP transformer encoder for text from microsoft"""

     def __init__(self, weights_path, t5version="google/t5-v1_1-large", freeze=True, device="cuda",
                  max_length=77):  # clip-vit-base-patch32
         super().__init__()

         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
         match_params = dict()
         for key in list(model_state_dict.keys()):
             if 'caption_encoder' in key:
                 match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]

         config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
         args = read_config_as_args(config_as_str, is_config_str=True)


         self.clap_tokenizer = AutoTokenizer.from_pretrained(args.text_model)  # args.text_model
         self.caption_encoder = TextEncoder(
             args.d_proj, args.text_model, args.transformer_embed_dim
         )

         self.t5_tokenizer = T5Tokenizer.from_pretrained(t5version)
         self.t5_transformer = T5EncoderModel.from_pretrained(t5version)

         self.max_length = max_length
         self.to(device=device)
         if freeze: self.freeze()

         print(
             f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")

     def freeze(self):
         self.caption_encoder = self.caption_encoder.eval()
         for param in self.caption_encoder.parameters():
             param.requires_grad = False

     def to(self, device):
         self.t5_transformer.to(device)
         self.caption_encoder.to(device)
         self.device = device

     def encode(self, text):
         ori_caption = text['ori_caption']
         struct_caption = text['struct_caption']
         # print(ori_caption,struct_caption)
         clap_batch_encoding = self.clap_tokenizer(ori_caption, truncation=True, max_length=self.max_length,
                                                   return_length=True,
                                                   return_overflowing_tokens=False, padding="max_length",
                                                   return_tensors="pt")
         ori_tokens = clap_batch_encoding["input_ids"].to(self.device)
         t5_batch_encoding = self.t5_tokenizer(struct_caption, truncation=True, max_length=self.max_length,
                                               return_length=True,
                                               return_overflowing_tokens=False, padding="max_length",
                                               return_tensors="pt")
         struct_tokens = t5_batch_encoding["input_ids"].to(self.device)
         outputs = self.caption_encoder.base(input_ids=ori_tokens)
         z = self.caption_encoder.projection(outputs.last_hidden_state)
         z2 = self.t5_transformer(input_ids=struct_tokens).last_hidden_state
         return torch.concat([z, z2], dim=1)
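The only functional change in this commit is the default device argument of FrozenFLANEmbedder, which moves from "cuda" to "cpu". The class only stores self.device and uses it to place the token ids, so callers who want GPU inference now have to request it explicitly and move the frozen T5 weights as well. A minimal usage sketch under those assumptions follows; the caption text and the explicit .to("cuda") call are illustrative, not part of the commit.

# Illustrative sketch only (not part of the commit): how the updated default behaves.
import torch
from ldm.modules.encoders.modules import FrozenFLANEmbedder

embedder = FrozenFLANEmbedder()                      # now defaults to device="cpu"
z = embedder.encode(["a dog barking in the rain"])   # -> (1, 77, hidden_size) tensor on CPU

# For GPU inference, pass the device explicitly and move the frozen T5 weights too:
if torch.cuda.is_available():
    embedder = FrozenFLANEmbedder(device="cuda").to("cuda")
    z = embedder.encode(["a dog barking in the rain"])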