from typing import Any, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor
import torchvision.transforms as T


def get_mnist_dataset(args: Any) -> DataLoader:
    """Build a binarized-MNIST DataLoader in image ('normal') or flattened-sequence form."""

    if args.dataset == 'normal':
        # Resize, then binarize pixels at 0.5; cast back to float for downstream use.
        transform = Compose([
            ToTensor(),
            Resize(args.image_size),
            T.Lambda(lambda x: (x > 0.5).float()),
        ])
    elif args.dataset == 'sequence':
        # Same binarization, then flatten each image into a single token sequence.
        transform = Compose([
            ToTensor(),
            Resize(args.image_size),
            T.Lambda(lambda x: (x > 0.5).float()),
            T.Lambda(lambda x: torch.flatten(x).unsqueeze(0)),
        ])
    else:
        raise ValueError("Please pick either 'normal' or 'sequence'")

    train_dataset = MNIST(
        root=args.data_root,
        download=args.download,
        transform=transform,
        train=True
    )
    train_dataloader = DataLoader(
        train_dataset,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )
    return train_dataloader
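
# A minimal usage sketch (hypothetical argument values; only the attribute
# names are taken from the function above):
#
#     from argparse import Namespace
#     args = Namespace(dataset='normal', data_root='./data', download=True,
#                      image_size=28, workers=2, batch_size=64)
#     loader = get_mnist_dataset(args)
#     images, labels = next(iter(loader))  # images: binarized floats in {0., 1.}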


# Protein preprocessing tools


def pad_ends(
    seqs: list,
    max_seq_length: int
) -> list:
    """Right-pad each sequence with '-' gap tokens up to max_seq_length."""

    padded_seqs = []
    for seq in seqs:
        pad_need = max_seq_length - len(seq)
        # '+=' handles both strings and token lists: a list is extended
        # with one '-' element per padding position.
        seq += '-' * pad_need
        padded_seqs.append(seq)
    return padded_seqs
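
# Sketch of pad_ends on token lists (values are illustrative):
#
#     pad_ends(seqs=[['<START>', 'A', 'C', '<END>']], max_seq_length=6)
#     # -> [['<START>', 'A', 'C', '<END>', '-', '-']]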


def create_num_seqs(seq_list: list) -> list:
    """Map every token in every sequence to its integer vocabulary index."""

    # Special tokens plus the 20 canonical amino acids and the '-' gap/pad token.
    tokens = ['<START>', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M',
              'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '<END>', '-']
    # Ambiguous and non-standard residue codes.
    tokens = tokens + ['X', 'U', 'Z', 'B', 'O']
    token2int = {x: ii for ii, x in enumerate(tokens)}

    num_seq_list = []
    for seq in seq_list:
        num_seq_list.append([token2int[aa] for aa in seq])
    return num_seq_list
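
# With the vocabulary above, '<START>' maps to 0, 'A' to 1, '<END>' to 21,
# and '-' to 22, e.g.:
#
#     create_num_seqs([['<START>', 'A', 'C', '<END>', '-', '-']])
#     # -> [[0, 1, 2, 21, 22, 22]]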


def prepare_protein_data(
    args: Any,
    data_dict: dict
) -> Tuple[list, list]:
    """Filter, pad, and tokenize protein sequences and align their text embeddings."""

    print(list(data_dict.keys()))
    print('Prepare dataset')

    # Strip any existing gap tokens, then wrap each sequence with <START>/<END>.
    seq_list = [seq.replace('-', '') for seq in data_dict[args.sequence_keyname]]
    seq_list = [['<START>'] + list(seq) + ['<END>'] for seq in seq_list]

    # Keep only sequences that fit within the diffusion-step budget.
    max_seq_len = int(args.diffusion_steps)
    valid_indices = [i for i, seq in enumerate(seq_list) if len(seq) <= max_seq_len]
    filter_seq_list = [seq_list[i] for i in valid_indices]

    # Pad every surviving sequence out to the image_size * image_size token grid.
    max_seq_len = int(args.image_size * args.image_size)
    padded_seq_list = pad_ends(
        seqs=filter_seq_list,
        max_seq_length=max_seq_len
    )
    num_seq_list = create_num_seqs(padded_seq_list)

    # Choose the embedding source that matches the facilitator setting.
    if args.facilitator in ['MSE', 'MMD']:
        text_emb = data_dict['text_to_protein_embedding']
    elif args.facilitator in ['Default']:
        text_emb = data_dict['text_embedding']
    else:
        raise ValueError(f"Unexpected value for 'facilitator': {args.facilitator}")

    # Keep the embeddings aligned with the filtered sequences.
    text_emb = [text_emb[i] for i in valid_indices]

    print('Finished preparing dataset')

    return (
        num_seq_list,
        text_emb
    )
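
# Hypothetical end-to-end sketch (the dict key name 'sequence' and the 512-dim
# embedding shape are assumptions for illustration, not fixed by this module):
#
#     from argparse import Namespace
#     args = Namespace(sequence_keyname='sequence', diffusion_steps=784,
#                      image_size=28, facilitator='Default')
#     data_dict = {'sequence': ['ACDE', 'MKV'],
#                  'text_embedding': torch.randn(2, 512)}
#     num_seqs, text_emb = prepare_protein_data(args=args, data_dict=data_dict)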


class protein_dataset(Dataset):
    """
    Dataset pairing tokenized protein sequences with their text embeddings.
    """

    def __init__(
        self,
        num_seq_list: list,
        text_emb: torch.Tensor
    ):
        if not torch.is_tensor(num_seq_list):
            self.num_seqs = torch.tensor(num_seq_list).float()
        else:
            # Already a tensor: cast and store it directly.
            self.num_seqs = num_seq_list.float()

        self.text_emb = text_emb

    def __len__(self):
        """
        Total number of samples.
        """
        return len(self.num_seqs)

    def __getitem__(self, idx: Any) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor
    ]:
        """
        Extract and return the data batch samples.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        num_seqs = self.num_seqs[idx]
        text_emb = self.text_emb[idx]

        return (
            num_seqs,
            text_emb
        )
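
# A minimal sketch wiring the pieces into a DataLoader (batch size arbitrary):
#
#     dataset = protein_dataset(num_seq_list=num_seqs, text_emb=text_emb)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     seq_batch, emb_batch = next(iter(loader))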