import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Resize
import torchvision.transforms as T
#from numba import jit
import numpy as np
import pandas as pd
from typing import Any, Tuple
def get_mnist_dataset(args: Any) -> DataLoader:
    if args.dataset == 'normal':
        # resize, then binarize each pixel (x > 0.5 yields a boolean image)
        transform = Compose([ToTensor(), Resize(args.image_size), lambda x: x > 0.5])
        train_dataset = MNIST(root=args.data_root, download=True, transform=transform, train=True)
        train_dataloader = DataLoader(
            train_dataset,
            num_workers=args.workers,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )
    elif args.dataset == 'sequence':
        # additionally flatten each binarized image into a (1, H*W) token sequence
        transform = Compose([ToTensor(), Resize(args.image_size), lambda x: x > 0.5, T.Lambda(lambda x: torch.flatten(x).unsqueeze(0))])
        train_dataset = MNIST(root=args.data_root, download=True, transform=transform, train=True)
        train_dataloader = DataLoader(
            train_dataset,
            num_workers=args.workers,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )
    else:
        print('Please pick either normal or sequence')
        quit()
    return train_dataloader
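
# Example usage (a minimal sketch; the argparse-style Namespace and its field
# values below are assumptions, mirroring the attributes this function reads):
#
#   from argparse import Namespace
#   args = Namespace(dataset='normal', data_root='./data', image_size=28,
#                    workers=2, batch_size=64)
#   loader = get_mnist_dataset(args)
#   imgs, labels = next(iter(loader))  # imgs: (64, 1, 28, 28) boolean tensor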

# Protein preprocessing tools
#@jit(nopython=True)
def pad_ends(
    seqs: list,
    max_seq_length: int
) -> list:
    # add padded gaps at the end of each sequence
    padded_seqs = []
    for seq in seqs:
        seq_length = len(seq)
        # number of pad tokens needed
        pad_need = max_seq_length - seq_length
        # append that many '-' pad tokens to the end (works for both strings
        # and token lists, since iterating '-'*pad_need yields one '-' per position)
        seq += '-' * pad_need
        padded_seqs.append(seq)
    return padded_seqs
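
# Example (sketch):
#   pad_ends([['<START>', 'A', 'C', '<END>']], max_seq_length=6)
#   -> [['<START>', 'A', 'C', '<END>', '-', '-']]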

# create numerically represented sequences
def create_num_seqs(seq_list: list) -> list:
    # tokenizer vocabulary
    #tokens = ['*', '<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>', '-']
    tokens = ['<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>', '-']
    # the ambiguous/non-standard amino acid codes also need to be in the token list
    tokens = tokens + ['X', 'U', 'Z', 'B', 'O']
    token2int = {x: ii for ii, x in enumerate(tokens)}
    # empty list to hold the numerical representations
    num_seq_list = []
    for seq in seq_list:
        num_seq_list.append([token2int[aa] for aa in seq])
    return num_seq_list
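
# Example (sketch): with the vocabulary above, '<START>' -> 0, 'A' -> 1,
# 'C' -> 2, '<END>' -> 21 and '-' -> 22, so:
#   create_num_seqs([['<START>', 'A', 'C', '<END>', '-']])
#   -> [[0, 1, 2, 21, 22]]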

# prepare the protein sequences
def prepare_protein_data(
    args: Any,
    data_dict: dict
) -> Tuple[list, list]:
    print(list(data_dict.keys()))
    print('Prepare dataset')
    # prepare sequences: strip alignment gaps, then add start/end tokens
    seq_list = [seq.replace('-', '') for seq in data_dict[args.sequence_keyname]]
    seq_list = [['<START>'] + list(seq) + ['<END>'] for seq in seq_list]
    seq_lens = [len(seq) for seq in seq_list]
    # determine the maximum sequence length based on the context window size
    max_seq_len = int(args.diffusion_steps)
    # get indices of sequences that fit inside the context window
    valid_indices = [i for i, seq in enumerate(seq_list) if len(seq) <= max_seq_len]
    # filter seq_list based on these indices
    filter_seq_list = [seq_list[i] for i in valid_indices]
    # pad every surviving sequence out to image_size * image_size tokens
    max_seq_len = int(args.image_size * args.image_size)
    padded_seq_list = pad_ends(
        seqs=filter_seq_list,
        max_seq_length=max_seq_len
    )
    num_seq_list = create_num_seqs(padded_seq_list)  # numerical representations
    # prepare the conditioning embeddings
    #class_label_list = df.label.values.tolist()
    if args.facilitator in ['MSE', 'MMD']:
        text_emb = data_dict['text_to_protein_embedding']
    elif args.facilitator in ['Default']:
        text_emb = data_dict['text_embedding']
    else:
        raise ValueError(f"Unexpected value for 'facilitator': {args.facilitator}")
    # prune the embeddings whose sequences were filtered out by length
    text_emb = [text_emb[i] for i in valid_indices]
    print('Finished preparing dataset')
    return (
        num_seq_list,
        text_emb
    )

class protein_dataset(Dataset):
    """
    Sequence dataset
    """
    def __init__(
        self,
        num_seq_list: list,
        text_emb: torch.Tensor
    ):
        if not torch.is_tensor(num_seq_list):
            self.num_seqs = torch.tensor(num_seq_list).float()
        else:
            self.num_seqs = num_seq_list
        self.text_emb = text_emb
        #if not torch.is_tensor(class_label_list):
        #    self.class_label = torch.tensor(class_label_list).float()

    def __len__(self):
        """
        total number of samples
        """
        return len(self.num_seqs)

    def __getitem__(self, idx: Any) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        extract and return the data batch samples
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # sequences
        num_seqs = self.num_seqs[idx]
        # conditioning embeddings
        text_emb = self.text_emb[idx]
        return (
            num_seqs,
            text_emb
        )
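
# Usage (a minimal sketch, assuming a data_dict loaded elsewhere with the keys
# read above; the 'protein_sequence' keyname and attribute values are hypothetical):
#
#   from argparse import Namespace
#   args = Namespace(sequence_keyname='protein_sequence', diffusion_steps=256,
#                    image_size=16, facilitator='Default')
#   num_seq_list, text_emb = prepare_protein_data(args, data_dict)
#   dataset = protein_dataset(num_seq_list, text_emb)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   seqs, emb = next(iter(loader))  # seqs: (32, 256) float tensor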