wenxiang guo committed
Commit 14db764 · verified · 1 Parent(s): 8c70e9a

Update ldm/modules/encoders/modules.py

Files changed (1):
  ldm/modules/encoders/modules.py (+4 -2)
ldm/modules/encoders/modules.py CHANGED

@@ -56,6 +56,7 @@ class FrozenFLANEmbedder(AbstractEncoder):
 
     def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77,
                  freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
 
         self.tokenizer = T5Tokenizer.from_pretrained(version)
@@ -88,8 +89,8 @@ class FrozenCLAPEmbedder(AbstractEncoder):
     """Uses the CLAP transformer encoder for text from microsoft"""
 
     def __init__(self, weights_path, freeze=True, device="cuda", max_length=77):  # clip-vit-base-patch32
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
-
         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
         match_params = dict()
         for key in list(model_state_dict.keys()):
@@ -103,7 +104,7 @@ class FrozenCLAPEmbedder(AbstractEncoder):
         self.caption_encoder = TextEncoder(
             args.d_proj, args.text_model, args.transformer_embed_dim
         )
-        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
         self.max_length = max_length
         self.device = device
         if freeze: self.freeze()
@@ -130,6 +131,7 @@ class FrozenCLAPFLANEmbedder(AbstractEncoder):
 
     def __init__(self, weights_path, t5version="google/t5-v1_1-large", freeze=True, device="cuda",
                  max_length=77):  # clip-vit-base-patch32
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         super().__init__()
 
         model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
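Taken together, the hunks move the CUDA-availability check out of the body of FrozenCLAPEmbedder.__init__ and place it at the top of each embedder's __init__ (FrozenFLANEmbedder, FrozenCLAPEmbedder, FrozenCLAPFLANEmbedder), so the device="cuda" default is overridden with a CPU fallback before any weights are touched. A minimal sketch of the pattern, using a hypothetical FrozenEmbedder stand-in rather than the repository's actual classes:

    import torch
    import torch.nn as nn

    class FrozenEmbedder(nn.Module):
        """Hypothetical stand-in for the repo's frozen text encoders."""

        def __init__(self, device="cuda", max_length=77, freeze=True):
            # Same fallback the commit adds: ignore the requested device
            # when no GPU is present, so CPU-only machines still work.
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
            super().__init__()
            self.max_length = max_length
            self.device = device  # later .to(self.device) calls now always target a real device
            if freeze:
                self.freeze()

        def freeze(self):
            # Frozen encoder: eval mode, no gradients.
            self.eval()
            for p in self.parameters():
                p.requires_grad = False

    embedder = FrozenEmbedder()
    print(embedder.device)  # cuda on GPU machines, cpu otherwise

Doing the check before super().__init__() (rather than after the encoder is built, as the old FrozenCLAPEmbedder did) means self.device always reflects an actually available device, so the class no longer crashes on CUDA-less hosts.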