Update model/openlamm.py
Browse files- model/openlamm.py +5 -1
model/openlamm.py
CHANGED
@@ -160,7 +160,11 @@ class LAMMPEFTModel(nn.Module):
|
|
160 |
encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
|
161 |
self.encoder_pretrain = encoder_pretrain
|
162 |
assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
vicuna_ckpt_path = args['vicuna_ckpt_path']
|
165 |
|
166 |
system_header = args['system_header'] if 'system_header' in args else False
|
|
|
160 |
encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
|
161 |
self.encoder_pretrain = encoder_pretrain
|
162 |
assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
|
163 |
+
if not encoder_pretrain == 'clip' or os.path.isfile(args['encoder_ckpt_path']):
|
164 |
+
encoder_ckpt_path = args['encoder_ckpt_path']
|
165 |
+
elif not os.path.isfile(args['encoder_ckpt_path']):
|
166 |
+
encoder_ckpt_path = 'ViT-L/14'
|
167 |
+
|
168 |
vicuna_ckpt_path = args['vicuna_ckpt_path']
|
169 |
|
170 |
system_header = args['system_header'] if 'system_header' in args else False
|