wenxiang guo committed
Update ldm/modules/encoders/modules.py
ldm/modules/encoders/modules.py CHANGED
@@ -56,6 +56,7 @@ class FrozenFLANEmbedder(AbstractEncoder):
 
     def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77,
                  freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
 
         self.tokenizer = T5Tokenizer.from_pretrained(version)
@@ -88,8 +89,8 @@ class FrozenCLAPEmbedder(AbstractEncoder):
     """Uses the CLAP transformer encoder for text from microsoft"""
 
     def __init__(self, weights_path, freeze=True, device="cuda", max_length=77):  # clip-vit-base-patch32
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
-
         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
         match_params = dict()
         for key in list(model_state_dict.keys()):
@@ -103,7 +104,7 @@ class FrozenCLAPEmbedder(AbstractEncoder):
         self.caption_encoder = TextEncoder(
             args.d_proj, args.text_model, args.transformer_embed_dim
         )
-
+
         self.max_length = max_length
         self.device = device
         if freeze: self.freeze()
@@ -130,6 +131,7 @@ class FrozenCLAPFLANEmbedder(AbstractEncoder):
 
     def __init__(self, weights_path, t5version="google/t5-v1_1-large", freeze=True, device="cuda",
                  max_length=77):  # clip-vit-base-patch32
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
 
         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']