Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn, cuda, load, device | |
from ps4_models.classifiers import PS4_Mega, PS4_Conv | |
def load_trained_model(load_path, model_name='PS4_Mega'): | |
if model_name.lower() not in ['ps4_conv', 'ps4_mega']: | |
raise ValueError(f'Model name {model_name} not recognised, please choose from PS4_Conv, PS4_Mega') | |
model: nn.Module = PS4_Mega() if model_name.lower() == 'ps4_mega' else PS4_Conv() | |
if load_path != '': | |
try: | |
if cuda.is_available(): | |
model.load_state_dict(load(load_path)['model_state_dict']) | |
else: | |
model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict']) | |
print("loded params from", load_path) | |
except: | |
raise ImportError(f'No file located at {load_path}, could not load parameters') | |
print(model) | |
pytorch_total_params = sum(par.numel() for par in model.parameters() if par.requires_grad) | |
print(pytorch_total_params) | |
return model | |
# MARK: sampling from new sequence | |
def sample_new_sequence(embs, weights_load_path, model_name='PS4_Mega'): | |
model = load_trained_model(weights_load_path, model_name) | |
seq_size = len(embs) | |
R = embs.view(1, seq_size) | |
pred_ss = '' | |
with torch.no_grad(): | |
y_hat = model(R) | |
probs = torch.softmax(y_hat, 2) | |
_, ss_preds = torch.max(probs, 2) | |
for i in range(seq_size): | |
ss = ss_preds[0][i].item() | |
ss = ss_tokeniser(ss, reverse=True) | |
pred_ss += ss | |
return pred_ss | |
def ss_tokeniser(ss, reverse=False): | |
ss_set = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C'] | |
if reverse: | |
return inverse_ss_tokeniser(ss) | |
else: | |
return 0 if (ss == 'P' or ss == ' ') else ss_set.index(ss) | |
def inverse_ss_tokeniser(ss): | |
ss_set = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C', 'C'] | |
return ss_set[ss] |