omarperacha's picture
embedings generate
5a79fe4
raw
history blame
1.92 kB
import torch
from torch import nn, cuda, load, device
from ps4_models.classifiers import PS4_Mega, PS4_Conv
def load_trained_model(load_path, model_name='PS4_Mega'):
    """Instantiate a PS4 classifier and optionally load trained weights.

    Args:
        load_path: Path to a torch checkpoint containing a
            'model_state_dict' entry. Pass '' to skip loading weights.
        model_name: Either 'PS4_Mega' or 'PS4_Conv' (case-insensitive).

    Returns:
        The model in its default (train) mode with parameters loaded
        when ``load_path`` was given.

    Raises:
        ValueError: If ``model_name`` is not a recognised architecture.
        ImportError: If the checkpoint cannot be read or lacks the
            expected 'model_state_dict' key (kept for backward
            compatibility with existing callers; the underlying error
            is chained as the cause).
    """
    if model_name.lower() not in ['ps4_conv', 'ps4_mega']:
        raise ValueError(f'Model name {model_name} not recognised, please choose from PS4_Conv, PS4_Mega')
    model: nn.Module = PS4_Mega() if model_name.lower() == 'ps4_mega' else PS4_Conv()
    if load_path != '':
        try:
            if cuda.is_available():
                model.load_state_dict(load(load_path)['model_state_dict'])
            else:
                # no GPU available: remap tensors saved on CUDA to CPU
                model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
            print("loaded params from", load_path)
        except Exception as err:
            # Narrowed from a bare `except:` (which also caught
            # KeyboardInterrupt/SystemExit); chain the real cause so it
            # is not silently discarded.
            raise ImportError(f'No file located at {load_path}, could not load parameters') from err
    print(model)
    # count only trainable parameters
    pytorch_total_params = sum(par.numel() for par in model.parameters() if par.requires_grad)
    print(pytorch_total_params)
    return model
# MARK: sampling from new sequence
def sample_new_sequence(embs, weights_load_path, model_name='PS4_Mega'):
    """Predict a secondary-structure string for one embedding sequence.

    Args:
        embs: 1-D tensor of per-residue embeddings (length = sequence length).
        weights_load_path: Checkpoint path forwarded to ``load_trained_model``.
        model_name: Architecture name, 'PS4_Mega' (default) or 'PS4_Conv'.

    Returns:
        A string with one secondary-structure character per residue.
    """
    model = load_trained_model(weights_load_path, model_name)
    n_residues = len(embs)
    # model expects a batch dimension: (1, seq_len)
    batch = embs.view(1, n_residues)
    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.softmax(logits, 2)
        _, class_ids = torch.max(class_probs, 2)
    # decode each argmax class index back to its structure letter
    return ''.join(
        ss_tokeniser(class_ids[0][idx].item(), reverse=True)
        for idx in range(n_residues)
    )
def ss_tokeniser(ss, reverse=False):
    """Convert a secondary-structure letter to its integer class.

    With ``reverse=True``, ``ss`` is an integer class and the matching
    letter is returned instead (delegates to ``inverse_ss_tokeniser``).
    'P' and ' ' both map to class 0 (coil, 'C').
    """
    if reverse:
        return inverse_ss_tokeniser(ss)
    if ss in ('P', ' '):
        return 0
    # trailing duplicate 'C' is inert: list.index returns the first match
    codes = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C']
    return codes.index(ss)
def inverse_ss_tokeniser(ss):
    """Return the secondary-structure letter for integer class ``ss``.

    Classes 0, 8 and 9 all decode to coil ('C').
    """
    return ('C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C', 'C')[ss]